• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
#include "tools/optimizer/fusion/reduce_stack_fusion.h"
#include <functional>
#include <numeric>
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/lite_exporter/fetch_content.h"
#include "ops/op_name.h"
#include "nnacl/op_base.h"
26 
27 namespace mindspore {
28 namespace opt {
Run(const FuncGraphPtr & func_graph)29 bool ReduceStackFusion::Run(const FuncGraphPtr &func_graph) {
30   if (func_graph == nullptr) {
31     MS_LOG(ERROR) << "func_graph is a nullptr, cannot do ReduceStackFusion.";
32     return false;
33   }
34   auto node_list = TopoSort(func_graph->get_return());
35   for (auto &node : node_list) {
36     if (!utils::isa<CNode>(node) || !CheckPrimitiveType(node, prim::kPrimStack)) {
37       continue;
38     }
39     auto stack_cnode = node->cast<CNodePtr>();
40     if (Process(func_graph, stack_cnode) != lite::RET_OK) {
41       MS_LOG(ERROR) << "Do ReduceStackFusion failed.";
42       return false;
43     }
44   }
45   return true;
46 }
47 
Process(const FuncGraphPtr & func_graph,const CNodePtr & stack)48 int ReduceStackFusion::Process(const FuncGraphPtr &func_graph, const CNodePtr &stack) {
49   MS_ASSERT(func_graph != nullptr && stack != nullptr);
50   if (!CheckCanFusion(func_graph, stack)) {
51     return lite::RET_OK;
52   }
53   reduce_prim_->AddAttr(ops::kKeepDims, MakeValue(true));
54   auto manager = func_graph->manager();
55   if (manager == nullptr) {
56     MS_LOG(ERROR) << "Manager is a nullptr.";
57     return lite::RET_NULL_PTR;
58   }
59   if (!manager->Replace(stack, stack->input(1))) {
60     MS_LOG(ERROR) << "do Manager-Replace failed.";
61     return lite::RET_ERROR;
62   }
63   return lite::RET_OK;
64 }
65 
CheckCanFusion(const FuncGraphPtr & func_graph,const CNodePtr & stack)66 bool ReduceStackFusion::CheckCanFusion(const FuncGraphPtr &func_graph, const CNodePtr &stack) {
67   MS_ASSERT(func_graph != nullptr && stack != nullptr);
68   if (IsMarkedTrainOp(stack)) {
69     return false;
70   }
71   if (stack->size() != kInputSizeTwo) {
72     return false;
73   }
74   auto prim = GetCNodePrimitive(stack);
75   MS_CHECK_TRUE_RET(prim != nullptr, false);
76   if (IsQuantParameterNode(prim)) {
77     return false;
78   }
79   auto axis = prim->GetAttr(ops::kAxis) == nullptr ? 0 : GetValue<int64_t>(prim->GetAttr(ops::kAxis));
80   auto input_node = stack->input(1);
81   if (!utils::isa<CNode>(input_node) || !CheckPrimitiveType(input_node, prim::kPrimReduceFusion)) {
82     return false;
83   }
84   auto reduce = input_node->cast<CNodePtr>();
85   return CheckReduce(func_graph, reduce, axis);
86 }
87 
CheckReduce(const FuncGraphPtr & func_graph,const CNodePtr & reduce,int stack_axis)88 bool ReduceStackFusion::CheckReduce(const FuncGraphPtr &func_graph, const CNodePtr &reduce, int stack_axis) {
89   MS_ASSERT(func_graph != nullptr && reduce != nullptr);
90   if (IsMarkedTrainOp(reduce)) {
91     return false;
92   }
93   if (IsMultiOutputTensors(func_graph, reduce)) {
94     return false;
95   }
96   if (reduce->size() < kInputSizeThree || reduce->input(ops::kInputIndex2) == nullptr ||
97       utils::isa<CNode>(reduce->input(ops::kInputIndex2))) {
98     return false;
99   }
100   reduce_prim_ = GetCNodePrimitive(reduce);
101   MS_CHECK_TRUE_RET(reduce_prim_ != nullptr, false);
102   if (IsQuantParameterNode(reduce_prim_)) {
103     return false;
104   }
105   bool keep_dim =
106     reduce_prim_->GetAttr(ops::kKeepDims) == nullptr ? false : GetValue<bool>(reduce_prim_->GetAttr(ops::kKeepDims));
107   if (keep_dim) {
108     return false;
109   }
110   lite::DataInfo data_info;
111   if (lite::FetchConstData(reduce, ops::kInputIndex2, converter::kFmkTypeMs, &data_info, false) != lite::RET_OK) {
112     return false;
113   }
114   if ((data_info.data_type_ != kNumberTypeInt && data_info.data_type_ != kNumberTypeInt32) ||
115       data_info.data_ptr_ == nullptr) {
116     return false;
117   }
118   auto num = std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1, std::multiplies<>());
119   if (num > 1) {
120     return false;
121   }
122   return *(static_cast<int *>(data_info.data_ptr_)) == stack_axis;
123 }
124 }  // namespace opt
125 }  // namespace mindspore
126