• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/converter/adapter/acl/mapper/reduce_fusion_mapper.h"
18 #include <memory>
19 #include <vector>
20 #include <algorithm>
21 #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
22 #include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
23 #include "tools/converter/adapter/acl/common/utils.h"
24 #include "tools/optimizer/common/gllo_utils.h"
25 #include "src/common/log_util.h"
26 #include "ops/op_utils.h"
27 #include "ops/auto_generate/gen_lite_ops.h"
28 #include "ops/lp_norm.h"
29 #include "tools/lite_exporter/fetch_content.h"
30 
31 namespace mindspore {
32 namespace lite {
namespace {
// Expected sizes of a ReduceFusion cnode's input list (the primitive itself
// counts as input 0):
//   2 = primitive + data          (axes carried only as an attribute)
//   3 = primitive + data + axes   (axes materialized as an explicit input)
constexpr auto kNameReduceMinInputNum = 2;
constexpr auto kNameReduceInputNum = 3;
}  // namespace
37 
Mapper(const CNodePtr & cnode)38 STATUS ReduceFusionMapper::Mapper(const CNodePtr &cnode) {
39   ValueNodePtr value_node = nullptr;
40   PrimitivePtr src_prim = nullptr;
41   if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
42     MS_LOG(ERROR) << "Get value node and primitive from cnode failed.";
43     return lite::RET_ERROR;
44   }
45   auto attr_val = src_prim->GetAttr(ops::kMode);
46   CHECK_NULL_RETURN(attr_val);
47   int64_t mode = GetValue<int64_t>(attr_val);
48   PrimitivePtr dst_prim = nullptr;
49   if (mode == static_cast<int64_t>(ReduceMode::Reduce_Sum)) {
50     ops::ReduceSum reduce_sum_op;
51     dst_prim = reduce_sum_op.GetPrim();
52   } else if (mode == static_cast<int64_t>(ReduceMode::Reduce_Mean)) {
53     ops::ReduceMean reduce_mean_op;
54     dst_prim = reduce_mean_op.GetPrim();
55   } else if (mode == static_cast<int64_t>(ReduceMode::Reduce_Max)) {
56     ops::ReduceMax reduce_max_op;
57     dst_prim = reduce_max_op.GetPrim();
58   } else if (mode == static_cast<int64_t>(ReduceMode::Reduce_Min)) {
59     ops::ReduceMin reduce_min_op;
60     dst_prim = reduce_min_op.GetPrim();
61   } else if (mode == static_cast<int64_t>(ReduceMode::Reduce_All)) {
62     ops::ReduceAll reduce_all;
63     dst_prim = reduce_all.GetPrim();
64   } else if (mode == static_cast<int64_t>(ReduceMode::Reduce_L2)) {
65     ops::LpNorm lp_norm_op;
66     auto axes_ptr = src_prim->GetAttr(ops::kAxes);
67     if (axes_ptr != nullptr) {
68       auto axes = GetValue<std::vector<int32_t>>(axes_ptr);
69       std::vector<int64_t> axes_vec;
70       std::transform(axes.begin(), axes.end(), std::back_inserter(axes_vec),
71                      [](int32_t x) { return static_cast<int64_t>(x); });
72       lp_norm_op.set_axis(axes_vec);
73     }
74     dst_prim = lp_norm_op.GetPrim();
75   } else if (mode == static_cast<int64_t>(ReduceMode::Reduce_L1)) {
76     ops::LpNorm lp_norm_op;
77     auto axes_ptr = src_prim->GetAttr(ops::kAxes);
78     if (axes_ptr != nullptr) {
79       auto axes = GetValue<std::vector<int32_t>>(axes_ptr);
80       std::vector<int64_t> axes_vec;
81       std::transform(axes.begin(), axes.end(), std::back_inserter(axes_vec),
82                      [](int32_t x) { return static_cast<int64_t>(x); });
83       lp_norm_op.set_axis(axes_vec);
84     }
85     auto keep_dims_ptr = src_prim->GetAttr(ops::kKeepDims);
86     if (keep_dims_ptr != nullptr) {
87       auto keep_dims = GetValue<bool>(keep_dims_ptr);
88       lp_norm_op.set_keep_dims(keep_dims);
89     }
90     lp_norm_op.set_p(1);
91     dst_prim = lp_norm_op.GetPrim();
92   } else if (mode == static_cast<int64_t>(ReduceMode::Reduce_Prod)) {
93     dst_prim = std::make_shared<acl::DynamicReduceProd>();
94   } else if (mode == static_cast<int64_t>(ReduceMode::Reduce_Log_Sum)) {
95     dst_prim = std::make_shared<acl::ReduceLogSum>();
96   } else if (mode == static_cast<int64_t>(ReduceMode::Reduce_Log_Sum_Exp)) {
97     dst_prim = std::make_shared<acl::ReduceLogSumExp>();
98   } else {
99     MS_LOG(ERROR) << "Not support reduce mode " << static_cast<int64_t>(mode);
100     return RET_ERROR;
101   }
102   CHECK_NULL_RETURN(dst_prim);
103   dst_prim->SetAttrs(src_prim->attrs());
104   value_node->set_value(dst_prim);
105   if (mode == static_cast<int64_t>(ReduceMode::Reduce_Mean)) {
106     return lite::RET_OK;
107   }
108   if (AdjustInput(cnode, dst_prim) != RET_OK) {
109     MS_LOG(ERROR) << "Adjust reduce input failed.";
110     return lite::RET_ERROR;
111   }
112   return RET_OK;
113 }
114 
GetAxes(const CNodePtr & cnode,int64_t mode,std::vector<int64_t> * axes,ParameterPtr axes_param,DataInfo data_info)115 STATUS GetAxes(const CNodePtr &cnode, int64_t mode, std::vector<int64_t> *axes, ParameterPtr axes_param,
116                DataInfo data_info) {
117   int data_len = 0;
118   std::vector<int> data_int;
119   std::vector<int64_t> data_int64;
120   if (data_info.data_type_ == kNumberTypeInt64) {
121     data_int64 = acl::GetInt64ParameterData(axes_param);
122     data_len = data_int64.size();
123   } else {
124     data_int = acl::GetIntParameterData(axes_param);
125     data_len = data_int.size();
126   }
127   if (cnode->size() == kNameReduceInputNum && mode == static_cast<int64_t>(ReduceMode::Reduce_Max) && data_len == 0) {
128     auto abstract = opt::GetCNodeInputAbstract(cnode, 1);
129     if (abstract == nullptr) {
130       MS_LOG(ERROR) << "GetCNodeInputAbstract in reduce_fusion!";
131       return lite::RET_ERROR;
132     }
133     std::vector<int64_t> shape = {};
134     if (opt::FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
135       MS_LOG(ERROR) << "FetchShapeFromAbstract failed!";
136       return lite::RET_ERROR;
137     }
138     int rank = shape.size();
139     if (data_info.data_type_ == kNumberTypeInt64) {
140       for (int dim = 0; dim < rank; dim++) {
141         data_int64.push_back(dim);
142       }
143     } else {
144       for (int dim = 0; dim < rank; dim++) {
145         data_int.push_back(dim);
146       }
147     }
148   }
149   if (data_info.data_type_ == kNumberTypeInt64) {
150     *axes = data_int64;
151   } else {
152     std::transform(data_int.begin(), data_int.end(), std::back_inserter(*axes),
153                    [](int32_t n) -> int64_t { return static_cast<int64_t>(n); });
154   }
155   return lite::RET_OK;
156 }
157 
AdjustInput(const CNodePtr & cnode,const PrimitivePtr & prim)158 STATUS ReduceFusionMapper::AdjustInput(const CNodePtr &cnode, const PrimitivePtr &prim) {
159   MS_ASSERT(cnode != nullptr && prim != nullptr);
160   auto attr_val = prim->GetAttr(ops::kMode);
161   CHECK_NULL_RETURN(attr_val);
162   int64_t mode = GetValue<int64_t>(attr_val);
163   if (cnode->size() == kNameReduceInputNum && mode == static_cast<int64_t>(ReduceMode::Reduce_Prod)) {
164     auto axes_ptr = prim->GetAttr(ops::kAxes);
165     if (axes_ptr != nullptr) {
166       auto axes = GetValue<std::vector<int32_t>>(axes_ptr);
167       if (axes.empty()) {
168         auto new_inputs = {cnode->input(0), cnode->input(1)};
169         cnode->set_inputs(new_inputs);
170       }
171     }
172   }
173   if (cnode->size() == kNameReduceMinInputNum) {
174     auto func_graph = cnode->func_graph();
175     CHECK_NULL_RETURN(func_graph);
176     auto attr_name = mode != static_cast<int64_t>(ReduceMode::Reduce_Prod) ? ops::kAxes : ops::kKeepDims;
177     auto ret = mode != static_cast<int64_t>(ReduceMode::Reduce_Prod)
178                  ? AddIntVecAttrToInput(func_graph, cnode, prim, attr_name)
179                  : AddIntAttrToInput(func_graph, cnode, prim, attr_name, true);
180     if (ret != lite::RET_OK) {
181       MS_LOG(ERROR) << "Add attr " << attr_name << " failed for cnode: " << cnode->fullname_with_scope();
182       return lite::RET_ERROR;
183     }
184     return lite::RET_OK;
185   }
186 
187   auto axes_input = cnode->input(kNameReduceInputNum - 1);
188   CHECK_NULL_RETURN(axes_input);
189   if (!utils::isa<ParameterPtr>(axes_input)) {
190     MS_LOG(ERROR) << "The reduce node is not parameter.";
191     return lite::RET_ERROR;
192   }
193   ParameterPtr axes_param = axes_input->cast<ParameterPtr>();
194   CHECK_NULL_RETURN(axes_param);
195   DataInfo data_info;
196   if (FetchFromDefaultParam(axes_param, converter::kFmkTypeMs, &data_info, true) != RET_OK) {
197     MS_LOG(ERROR) << "fetch information from default param failed!";
198     return lite::RET_ERROR;
199   }
200 
201   std::vector<int64_t> axes;
202   auto ret = GetAxes(cnode, mode, &axes, axes_param, data_info);
203   if (ret != lite::RET_OK) {
204     MS_LOG(ERROR) << "Get axes failed! ret:" << ret << "!";
205     return ret;
206   }
207   ValueNodePtr value_node = NewValueNode<std::vector<int64_t>>(axes);
208   std::vector<int64_t> shape_vec_shape = {};
209   auto abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shape_vec_shape);
210   value_node->set_abstract(abstract);
211   CHECK_NULL_RETURN(value_node);
212   cnode->set_input(kNameReduceInputNum - 1, value_node);
213   return lite::RET_OK;
214 }
215 
// Registers ReduceFusionMapper in the converter's primitive-mapper registry
// under the ReduceFusion op name, so the ACL adapter dispatches to it.
REGISTER_PRIMITIVE_MAPPER(kNameReduceFusion, ReduceFusionMapper)
217 }  // namespace lite
218 }  // namespace mindspore
219