1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/converter/adapter/acl/mapper/reduce_fusion_mapper.h"
18 #include <memory>
19 #include <vector>
20 #include <algorithm>
21 #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
22 #include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
23 #include "tools/converter/adapter/acl/common/utils.h"
24 #include "tools/optimizer/common/gllo_utils.h"
25 #include "src/common/log_util.h"
26 #include "ops/op_utils.h"
27 #include "ops/auto_generate/gen_lite_ops.h"
28 #include "ops/lp_norm.h"
29 #include "tools/lite_exporter/fetch_content.h"
30
31 namespace mindspore {
32 namespace lite {
namespace {
// cnode->size() of a reduce node that carries no axes input yet; AdjustInput appends the
// missing attribute as an extra input for such nodes.
constexpr int kNameReduceMinInputNum = 2;
// cnode->size() of a reduce node whose last input (index kNameReduceInputNum - 1) holds the axes.
constexpr int kNameReduceInputNum = 3;
}  // namespace
37
Mapper(const CNodePtr & cnode)38 STATUS ReduceFusionMapper::Mapper(const CNodePtr &cnode) {
39 ValueNodePtr value_node = nullptr;
40 PrimitivePtr src_prim = nullptr;
41 if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
42 MS_LOG(ERROR) << "Get value node and primitive from cnode failed.";
43 return lite::RET_ERROR;
44 }
45 auto attr_val = src_prim->GetAttr(ops::kMode);
46 CHECK_NULL_RETURN(attr_val);
47 int64_t mode = GetValue<int64_t>(attr_val);
48 PrimitivePtr dst_prim = nullptr;
49 if (mode == static_cast<int64_t>(ReduceMode::Reduce_Sum)) {
50 ops::ReduceSum reduce_sum_op;
51 dst_prim = reduce_sum_op.GetPrim();
52 } else if (mode == static_cast<int64_t>(ReduceMode::Reduce_Mean)) {
53 ops::ReduceMean reduce_mean_op;
54 dst_prim = reduce_mean_op.GetPrim();
55 } else if (mode == static_cast<int64_t>(ReduceMode::Reduce_Max)) {
56 ops::ReduceMax reduce_max_op;
57 dst_prim = reduce_max_op.GetPrim();
58 } else if (mode == static_cast<int64_t>(ReduceMode::Reduce_Min)) {
59 ops::ReduceMin reduce_min_op;
60 dst_prim = reduce_min_op.GetPrim();
61 } else if (mode == static_cast<int64_t>(ReduceMode::Reduce_All)) {
62 ops::ReduceAll reduce_all;
63 dst_prim = reduce_all.GetPrim();
64 } else if (mode == static_cast<int64_t>(ReduceMode::Reduce_L2)) {
65 ops::LpNorm lp_norm_op;
66 auto axes_ptr = src_prim->GetAttr(ops::kAxes);
67 if (axes_ptr != nullptr) {
68 auto axes = GetValue<std::vector<int32_t>>(axes_ptr);
69 std::vector<int64_t> axes_vec;
70 std::transform(axes.begin(), axes.end(), std::back_inserter(axes_vec),
71 [](int32_t x) { return static_cast<int64_t>(x); });
72 lp_norm_op.set_axis(axes_vec);
73 }
74 dst_prim = lp_norm_op.GetPrim();
75 } else if (mode == static_cast<int64_t>(ReduceMode::Reduce_L1)) {
76 ops::LpNorm lp_norm_op;
77 auto axes_ptr = src_prim->GetAttr(ops::kAxes);
78 if (axes_ptr != nullptr) {
79 auto axes = GetValue<std::vector<int32_t>>(axes_ptr);
80 std::vector<int64_t> axes_vec;
81 std::transform(axes.begin(), axes.end(), std::back_inserter(axes_vec),
82 [](int32_t x) { return static_cast<int64_t>(x); });
83 lp_norm_op.set_axis(axes_vec);
84 }
85 auto keep_dims_ptr = src_prim->GetAttr(ops::kKeepDims);
86 if (keep_dims_ptr != nullptr) {
87 auto keep_dims = GetValue<bool>(keep_dims_ptr);
88 lp_norm_op.set_keep_dims(keep_dims);
89 }
90 lp_norm_op.set_p(1);
91 dst_prim = lp_norm_op.GetPrim();
92 } else if (mode == static_cast<int64_t>(ReduceMode::Reduce_Prod)) {
93 dst_prim = std::make_shared<acl::DynamicReduceProd>();
94 } else if (mode == static_cast<int64_t>(ReduceMode::Reduce_Log_Sum)) {
95 dst_prim = std::make_shared<acl::ReduceLogSum>();
96 } else if (mode == static_cast<int64_t>(ReduceMode::Reduce_Log_Sum_Exp)) {
97 dst_prim = std::make_shared<acl::ReduceLogSumExp>();
98 } else {
99 MS_LOG(ERROR) << "Not support reduce mode " << static_cast<int64_t>(mode);
100 return RET_ERROR;
101 }
102 CHECK_NULL_RETURN(dst_prim);
103 dst_prim->SetAttrs(src_prim->attrs());
104 value_node->set_value(dst_prim);
105 if (mode == static_cast<int64_t>(ReduceMode::Reduce_Mean)) {
106 return lite::RET_OK;
107 }
108 if (AdjustInput(cnode, dst_prim) != RET_OK) {
109 MS_LOG(ERROR) << "Adjust reduce input failed.";
110 return lite::RET_ERROR;
111 }
112 return RET_OK;
113 }
114
GetAxes(const CNodePtr & cnode,int64_t mode,std::vector<int64_t> * axes,ParameterPtr axes_param,DataInfo data_info)115 STATUS GetAxes(const CNodePtr &cnode, int64_t mode, std::vector<int64_t> *axes, ParameterPtr axes_param,
116 DataInfo data_info) {
117 int data_len = 0;
118 std::vector<int> data_int;
119 std::vector<int64_t> data_int64;
120 if (data_info.data_type_ == kNumberTypeInt64) {
121 data_int64 = acl::GetInt64ParameterData(axes_param);
122 data_len = data_int64.size();
123 } else {
124 data_int = acl::GetIntParameterData(axes_param);
125 data_len = data_int.size();
126 }
127 if (cnode->size() == kNameReduceInputNum && mode == static_cast<int64_t>(ReduceMode::Reduce_Max) && data_len == 0) {
128 auto abstract = opt::GetCNodeInputAbstract(cnode, 1);
129 if (abstract == nullptr) {
130 MS_LOG(ERROR) << "GetCNodeInputAbstract in reduce_fusion!";
131 return lite::RET_ERROR;
132 }
133 std::vector<int64_t> shape = {};
134 if (opt::FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
135 MS_LOG(ERROR) << "FetchShapeFromAbstract failed!";
136 return lite::RET_ERROR;
137 }
138 int rank = shape.size();
139 if (data_info.data_type_ == kNumberTypeInt64) {
140 for (int dim = 0; dim < rank; dim++) {
141 data_int64.push_back(dim);
142 }
143 } else {
144 for (int dim = 0; dim < rank; dim++) {
145 data_int.push_back(dim);
146 }
147 }
148 }
149 if (data_info.data_type_ == kNumberTypeInt64) {
150 *axes = data_int64;
151 } else {
152 std::transform(data_int.begin(), data_int.end(), std::back_inserter(*axes),
153 [](int32_t n) -> int64_t { return static_cast<int64_t>(n); });
154 }
155 return lite::RET_OK;
156 }
157
AdjustInput(const CNodePtr & cnode,const PrimitivePtr & prim)158 STATUS ReduceFusionMapper::AdjustInput(const CNodePtr &cnode, const PrimitivePtr &prim) {
159 MS_ASSERT(cnode != nullptr && prim != nullptr);
160 auto attr_val = prim->GetAttr(ops::kMode);
161 CHECK_NULL_RETURN(attr_val);
162 int64_t mode = GetValue<int64_t>(attr_val);
163 if (cnode->size() == kNameReduceInputNum && mode == static_cast<int64_t>(ReduceMode::Reduce_Prod)) {
164 auto axes_ptr = prim->GetAttr(ops::kAxes);
165 if (axes_ptr != nullptr) {
166 auto axes = GetValue<std::vector<int32_t>>(axes_ptr);
167 if (axes.empty()) {
168 auto new_inputs = {cnode->input(0), cnode->input(1)};
169 cnode->set_inputs(new_inputs);
170 }
171 }
172 }
173 if (cnode->size() == kNameReduceMinInputNum) {
174 auto func_graph = cnode->func_graph();
175 CHECK_NULL_RETURN(func_graph);
176 auto attr_name = mode != static_cast<int64_t>(ReduceMode::Reduce_Prod) ? ops::kAxes : ops::kKeepDims;
177 auto ret = mode != static_cast<int64_t>(ReduceMode::Reduce_Prod)
178 ? AddIntVecAttrToInput(func_graph, cnode, prim, attr_name)
179 : AddIntAttrToInput(func_graph, cnode, prim, attr_name, true);
180 if (ret != lite::RET_OK) {
181 MS_LOG(ERROR) << "Add attr " << attr_name << " failed for cnode: " << cnode->fullname_with_scope();
182 return lite::RET_ERROR;
183 }
184 return lite::RET_OK;
185 }
186
187 auto axes_input = cnode->input(kNameReduceInputNum - 1);
188 CHECK_NULL_RETURN(axes_input);
189 if (!utils::isa<ParameterPtr>(axes_input)) {
190 MS_LOG(ERROR) << "The reduce node is not parameter.";
191 return lite::RET_ERROR;
192 }
193 ParameterPtr axes_param = axes_input->cast<ParameterPtr>();
194 CHECK_NULL_RETURN(axes_param);
195 DataInfo data_info;
196 if (FetchFromDefaultParam(axes_param, converter::kFmkTypeMs, &data_info, true) != RET_OK) {
197 MS_LOG(ERROR) << "fetch information from default param failed!";
198 return lite::RET_ERROR;
199 }
200
201 std::vector<int64_t> axes;
202 auto ret = GetAxes(cnode, mode, &axes, axes_param, data_info);
203 if (ret != lite::RET_OK) {
204 MS_LOG(ERROR) << "Get axes failed! ret:" << ret << "!";
205 return ret;
206 }
207 ValueNodePtr value_node = NewValueNode<std::vector<int64_t>>(axes);
208 std::vector<int64_t> shape_vec_shape = {};
209 auto abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shape_vec_shape);
210 value_node->set_abstract(abstract);
211 CHECK_NULL_RETURN(value_node);
212 cnode->set_input(kNameReduceInputNum - 1, value_node);
213 return lite::RET_OK;
214 }
215
// Presumably registers ReduceFusionMapper as the mapper used for kNameReduceFusion nodes; the
// macro is not defined in this file (likely primitive_mapper_register.h) - confirm there.
216 REGISTER_PRIMITIVE_MAPPER(kNameReduceFusion, ReduceFusionMapper)
217 } // namespace lite
218 } // namespace mindspore
219