• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/scale_activation_fusion.h"
19 #include <memory>
20 #include "mindspore/core/ops/lite_ops.h"
21 #include "ops/fusion/activation.h"
22 #include "ops/fusion/scale_fusion.h"
23 #include "ops/op_utils.h"
24 #include "tools/optimizer/common/gllo_utils.h"
25 #include "nnacl/op_base.h"
26 
27 namespace mindspore::opt {
DefinePattern() const28 const BaseRef ScaleActivationFusion::DefinePattern() const {
29   auto is_scale = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimScaleFusion>);
30   MS_CHECK_TRUE_RET(is_scale != nullptr, {});
31   auto is_activation = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimActivation>);
32   MS_CHECK_TRUE_RET(is_activation != nullptr, {});
33   return VectorRef({is_activation, is_scale});
34 }
35 
// Folds a RELU / RELU6 Activation that directly consumes a ScaleFusion output into that ScaleFusion
// node by writing the resulting type into the ScaleFusion's activation_type attribute.
//
// func_graph: graph being optimized; must be non-null (RET_NULL_PTR is recorded otherwise).
// node:       the matched Activation node.
// Returns the ScaleFusion node (presumably substituted for `node` by the pattern pass framework),
// or nullptr when the fusion does not apply.
const AnfNodePtr ScaleActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                const EquivPtr &) const {
  if (func_graph == nullptr || node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }

  // --- Validate the consuming Activation node. ---
  auto act_node = node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(act_node != nullptr, nullptr);
  if (!CheckPrimitiveType(act_node, prim::kPrimActivation) || IsMarkedTrainOp(act_node)) {
    return nullptr;
  }
  // Primitive input plus exactly one data input.
  MS_CHECK_TRUE_RET(act_node->size() == kInputSizeTwo, nullptr);
  auto act_prim = ops::GetOperator<mindspore::ops::Activation>(act_node->input(FIRST_INPUT));
  MS_CHECK_TRUE_RET(act_prim != nullptr, nullptr);
  auto act_prim_c = act_prim->GetPrim();
  MS_CHECK_TRUE_RET(act_prim_c != nullptr && act_prim_c->GetAttr(ops::kActivationType) != nullptr, nullptr);
  ActivationType act_type = act_prim->get_activation_type();
  // Only activations that ScaleFusion can carry are fused.
  if (act_type != mindspore::RELU && act_type != mindspore::RELU6) {
    return nullptr;
  }

  // --- Validate the producing ScaleFusion node. ---
  auto scale_node = act_node->input(SECOND_INPUT);
  MS_CHECK_TRUE_RET(scale_node != nullptr, nullptr);
  auto scale_cnode = scale_node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(scale_cnode != nullptr, nullptr);
  // Mirror the explicit type check done for the Activation node above.
  if (!CheckPrimitiveType(scale_cnode, prim::kPrimScaleFusion)) {
    return nullptr;
  }
  // If the scale output feeds other consumers, they must keep seeing the un-activated value.
  if (IsMarkedTrainOp(scale_cnode) || IsMultiOutputTensors(func_graph, scale_cnode)) {
    return nullptr;
  }
  auto scale_prim = ops::GetOperator<ops::ScaleFusion>(scale_cnode->input(FIRST_INPUT));
  // Runtime check, not MS_ASSERT: the pointer is dereferenced on the next line and a debug-only
  // assert would leave release builds exposed to a null dereference.
  MS_CHECK_TRUE_RET(scale_prim != nullptr, nullptr);
  auto scale_prim_c = scale_prim->GetPrim();
  MS_CHECK_TRUE_RET(scale_prim_c != nullptr, nullptr);

  // --- Merge activations. ---
  // If the scale already has an activation, it must be RELU or RELU6. Composing RELU and RELU6 in
  // either order yields RELU6, so RELU6 wins; otherwise keep the Activation node's type.
  if (scale_prim_c->GetAttr(ops::kActivationType) != nullptr && scale_prim->get_activation_type() != NO_ACTIVATION) {
    auto scale_act = scale_prim->get_activation_type();
    MS_CHECK_TRUE_RET(scale_act == RELU || scale_act == RELU6, nullptr);
    act_type = scale_act == RELU6 ? RELU6 : act_type;
  }
  (void)scale_prim_c->AddAttr(ops::kActivationType, MakeValue<int64_t>(static_cast<int64_t>(act_type)));
  return scale_node;
}
76 }  // namespace mindspore::opt
77