1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
18 #include <map>
19 #include <vector>
20 #include "tools/converter/adapter/acl/common/utils.h"
21 #include "tools/optimizer/common/gllo_utils.h"
22 #include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
23 #include "ir/graph_utils.h"
24 #include "include/errorcode.h"
25 #include "include/registry/converter_context.h"
26 #include "ops/op_utils.h"
27 #include "ops/fusion/avg_pool_fusion.h"
28 #include "ops/fusion/max_pool_fusion.h"
29 #include "plugin/device/cpu/kernel/nnacl/op_base.h"
30 #include "src/common/log_util.h"
31
32 namespace mindspore {
33 namespace lite {
namespace {
// Element count of an attribute value that AttrAdjust expands to a 4-element vector.
constexpr int kCommonAttrValueNum = 2;
// Attribute name under which the pad mode string is stored when the source primitive is acl::kNameMaxPoolV3.
constexpr const char *kNamePaddingMode = "padding_mode";
// Attribute name of the boolean flag that is true when the round mode equals RoundMode::CEIL.
constexpr const char *kNameCeilMode = "ceil_mode";
}  // namespace
39
Mapper(const CNodePtr & cnode)40 STATUS PrimitiveMapper::Mapper(const CNodePtr &cnode) { return lite::RET_OK; }
41
GetValueNodeAndPrimFromCnode(const CNodePtr & cnode,ValueNodePtr * value_node,PrimitivePtr * prim_ptr) const42 STATUS PrimitiveMapper::GetValueNodeAndPrimFromCnode(const CNodePtr &cnode, ValueNodePtr *value_node,
43 PrimitivePtr *prim_ptr) const {
44 CHECK_NULL_RETURN(cnode);
45 CHECK_NULL_RETURN(value_node);
46 CHECK_NULL_RETURN(prim_ptr);
47 CHECK_NULL_RETURN(cnode->input(0));
48
49 *value_node = cnode->input(0)->cast<ValueNodePtr>();
50 if (*value_node == nullptr) {
51 MS_LOG(ERROR) << "Value node[" << cnode->fullname_with_scope() << "] is nullptr.";
52 return lite::RET_ERROR;
53 }
54 *prim_ptr = GetValueNode<PrimitivePtr>(*value_node);
55 if (*prim_ptr == nullptr) {
56 MS_LOG(ERROR) << "Value node[" << cnode->fullname_with_scope() << "] cast to primitive failed.";
57 return lite::RET_ERROR;
58 }
59 return lite::RET_OK;
60 }
61
AttrAdjust(const PrimitivePtr & prim,const std::string & name) const62 STATUS PrimitiveMapper::AttrAdjust(const PrimitivePtr &prim, const std::string &name) const {
63 MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "prim is nullptr.");
64 auto value_ptr = prim->GetAttr(name);
65 if (value_ptr == nullptr) {
66 MS_LOG(WARNING) << prim->name() << " has no attr " << name;
67 return lite::RET_OK;
68 }
69 if (utils::isa<ValueSequencePtr>(value_ptr)) {
70 auto val_seq_ptr = value_ptr->cast<ValueSequencePtr>();
71 CHECK_NULL_RETURN(val_seq_ptr);
72 ValuePtr first_val = nullptr;
73 if (!val_seq_ptr->value().empty()) {
74 first_val = val_seq_ptr->value().front();
75 }
76 CHECK_NULL_RETURN(first_val);
77 CHECK_NULL_RETURN(first_val->type());
78 if (first_val->type()->number_type() != kNumberTypeInt64) {
79 MS_LOG(ERROR) << "Value number type of name: " << prim->name() << " ,please check the attr name: " << name;
80 return lite::RET_ERROR;
81 }
82 } else {
83 CHECK_NULL_RETURN(value_ptr->type());
84 if (value_ptr->type()->number_type() != kNumberTypeInt64) {
85 MS_LOG(ERROR) << "Value number type of name: " << prim->name() << " ,please check the attr name: " << name;
86 return lite::RET_ERROR;
87 }
88 }
89 auto origin_value = opt::CastToInt(value_ptr);
90 if (origin_value.size() == kCommonAttrValueNum) {
91 // expand to 4
92 int64_t format = Format::NCHW;
93 if (prim->GetAttr(ops::kFormat) != nullptr) {
94 format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
95 }
96 std::vector<int64_t> new_value = {1, 1, static_cast<int64_t>(origin_value[0]),
97 static_cast<int64_t>(origin_value[1])};
98 if (format == Format::NHWC) {
99 std::vector<int64_t> tmp = {1, static_cast<int64_t>(origin_value[0]), static_cast<int64_t>(origin_value[1]), 1};
100 new_value.swap(tmp);
101 }
102 prim->AddAttr(name, MakeValue(new_value));
103 }
104 return lite::RET_OK;
105 }
106
AdjustCaffePoolAttr(const std::string & src_prim_name,const PrimitivePtr & dst_prim) const107 void PrimitiveMapper::AdjustCaffePoolAttr(const std::string &src_prim_name, const PrimitivePtr &dst_prim) const {
108 int64_t mode = src_prim_name == ops::kNameAvgPoolFusion ? 1 : 0;
109 dst_prim->AddAttr(ops::kMode, MakeValue(mode));
110
111 auto run_mode_val = dst_prim->GetAttr(ops::kRoundMode);
112 if (run_mode_val == nullptr) {
113 MS_LOG(INFO) << "There is no attr run mode";
114 return;
115 }
116 auto run_mode = GetValue<int64_t>(run_mode_val);
117 int64_t run_mode_ge = run_mode == RoundMode::FLOOR ? 1 : 0;
118 dst_prim->set_attr(ops::kRoundMode, MakeValue(run_mode_ge));
119 }
120
AdjustOnnxPoolAttr(const std::string & src_prim_name,const PrimitivePtr & dst_prim) const121 void PrimitiveMapper::AdjustOnnxPoolAttr(const std::string &src_prim_name, const PrimitivePtr &dst_prim) const {
122 auto pad_mode_val = dst_prim->GetAttr(ops::kPadMode);
123 static std::map<int64_t, std::string> kPadModToStrMap = {
124 {PadMode::PAD, "CALCULATED"},
125 {PadMode::SAME, "SAME"},
126 {PadMode::VALID, "VALID"},
127 };
128 if (pad_mode_val) {
129 auto pad_mode = GetValue<int64_t>(pad_mode_val);
130 std::string padding_mode = "CALCULATED";
131 if (kPadModToStrMap.find(pad_mode) != kPadModToStrMap.end()) {
132 padding_mode = kPadModToStrMap[pad_mode];
133 }
134 if ((src_prim_name == ops::kNameMaxPool || dst_prim->name() == ops::kNameAvgPool) && padding_mode == "CALCULATED") {
135 padding_mode = "VALID";
136 }
137 std::string pad_mode_name = src_prim_name == acl::kNameMaxPoolV3 ? kNamePaddingMode : ops::kPadMode;
138 dst_prim->AddAttr(pad_mode_name, MakeValue(padding_mode));
139 } else {
140 MS_LOG(INFO) << "There is no attr pad mode";
141 }
142 auto run_mode_val = dst_prim->GetAttr(ops::kRoundMode);
143 if (run_mode_val) {
144 int64_t run_mode = GetValue<int64_t>(run_mode_val);
145 bool ceil_mode = run_mode == RoundMode::CEIL;
146 dst_prim->AddAttr(kNameCeilMode, MakeValue(ceil_mode));
147 } else {
148 MS_LOG(INFO) << "There is no attr run mode";
149 }
150 }
151
AdjustPoolAttr(int fmk_type,const std::string & src_prim_name,const PrimitivePtr & dst_prim) const152 STATUS PrimitiveMapper::AdjustPoolAttr(int fmk_type, const std::string &src_prim_name,
153 const PrimitivePtr &dst_prim) const {
154 if (fmk_type == converter::kFmkTypeCaffe) {
155 AdjustCaffePoolAttr(src_prim_name, dst_prim);
156 return lite::RET_OK;
157 } else if (fmk_type == converter::kFmkTypeOnnx) {
158 AdjustOnnxPoolAttr(src_prim_name, dst_prim);
159 }
160 // adjust common attr
161 MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr.");
162 auto status = AttrAdjust(dst_prim, ops::kKernelSize);
163 if (status != lite::RET_OK) {
164 MS_LOG(ERROR) << "Adjust kernel size failed.";
165 return status;
166 }
167 status = AttrAdjust(dst_prim, ops::kStrides);
168 if (status != lite::RET_OK) {
169 MS_LOG(ERROR) << "adjust strides failed.";
170 return status;
171 }
172 return lite::RET_OK;
173 }
174
MoveAttrMap(const CNodePtr & cnode,const PrimitivePtr & dst_prim) const175 STATUS PrimitiveMapper::MoveAttrMap(const CNodePtr &cnode, const PrimitivePtr &dst_prim) const {
176 ValueNodePtr value_node = nullptr;
177 PrimitivePtr src_prim = nullptr;
178 if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
179 MS_LOG(ERROR) << "Get primitive from cnode failed.";
180 return lite::RET_ERROR;
181 }
182 MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr.");
183 dst_prim->SetAttrs(src_prim->attrs());
184 value_node->set_value(dst_prim);
185 return lite::RET_OK;
186 }
187
AddFloatAttrToInput(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const PrimitivePtr & dst_prim,const std::string & attr_name,bool empty_shape) const188 STATUS PrimitiveMapper::AddFloatAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
189 const PrimitivePtr &dst_prim, const std::string &attr_name,
190 bool empty_shape) const {
191 MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr.");
192 auto attr_val = dst_prim->GetAttr(attr_name);
193 if (attr_val == nullptr) {
194 MS_LOG(INFO) << "There is no attr: " << attr_name;
195 return lite::RET_OK;
196 }
197 auto param_name = cnode->fullname_with_scope() + "_" + attr_name;
198 auto inputs = cnode->inputs();
199 auto value_data = GetValue<float>(attr_val);
200 auto param_node = opt::BuildFloatValueParameterNode(func_graph, value_data, param_name, empty_shape);
201
202 MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "param_node is nullptr.");
203 param_node->set_debug_info(std::make_shared<NodeDebugInfo>(param_name));
204 inputs.push_back(param_node);
205 cnode->set_inputs(inputs);
206 return lite::RET_OK;
207 }
208
AddIntVecAttrToInput(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const PrimitivePtr & dst_prim,const std::string & attr_name) const209 STATUS PrimitiveMapper::AddIntVecAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
210 const PrimitivePtr &dst_prim, const std::string &attr_name) const {
211 MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr.");
212 auto attr_val = dst_prim->GetAttr(attr_name);
213 if (attr_val == nullptr) {
214 MS_LOG(INFO) << "There is no attr: " << attr_name;
215 return lite::RET_OK;
216 }
217 auto param_name = cnode->fullname_with_scope() + "_" + attr_name;
218 auto inputs = cnode->inputs();
219 auto value_data = opt::CastToVec2DInt(attr_val);
220 auto param_node = opt::BuildIntVec2DParameterNode(func_graph, value_data, param_name);
221
222 MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "param_node is nullptr.");
223 param_node->set_debug_info(std::make_shared<NodeDebugInfo>(param_name));
224 inputs.push_back(param_node);
225 cnode->set_inputs(inputs);
226 return lite::RET_OK;
227 }
228
AddIntAttrToInput(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const PrimitivePtr & dst_prim,const std::string & attr_name,bool empty_shape) const229 STATUS PrimitiveMapper::AddIntAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
230 const PrimitivePtr &dst_prim, const std::string &attr_name,
231 bool empty_shape) const {
232 MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr.");
233 auto attr_val = dst_prim->GetAttr(attr_name);
234 if (attr_val == nullptr) {
235 MS_LOG(INFO) << "There is no attr: " << attr_name;
236 return lite::RET_OK;
237 }
238 auto param_name = cnode->fullname_with_scope() + "_" + attr_name;
239 auto inputs = cnode->inputs();
240 auto value_data = opt::CastToInt(attr_val);
241 if (value_data.size() < 1) {
242 MS_LOG(ERROR) << "Invalid size: " << value_data.size();
243 return lite::RET_ERROR;
244 }
245 auto param_node = opt::BuildIntValueParameterNode(func_graph, value_data[0], param_name, empty_shape);
246
247 MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "param_node is nullptr.");
248 param_node->set_debug_info(std::make_shared<NodeDebugInfo>(param_name));
249 inputs.push_back(param_node);
250 cnode->set_inputs(inputs);
251 return lite::RET_OK;
252 }
253
AddAttrForDynInputPrimitive(const CNodePtr & cnode) const254 STATUS PrimitiveMapper::AddAttrForDynInputPrimitive(const CNodePtr &cnode) const {
255 CHECK_NULL_RETURN(cnode);
256 CHECK_NULL_RETURN(cnode->input(0));
257 auto value_node = cnode->input(0)->cast<ValueNodePtr>();
258 CHECK_NULL_RETURN(value_node);
259 auto prim = GetValueNode<PrimitivePtr>(value_node);
260 CHECK_NULL_RETURN(prim);
261 // add attr input num for dynamic input op
262 int64_t num = static_cast<int64_t>(cnode->size());
263 if (num > 1) {
264 prim->AddAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{num - 1, -1}));
265 }
266 return lite::RET_OK;
267 }
268
AdjustAttrFormat(const PrimitivePtr & prim,const std::string & name) const269 STATUS PrimitiveMapper::AdjustAttrFormat(const PrimitivePtr &prim, const std::string &name) const {
270 int64_t format = Format::NCHW;
271 MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "prim is nullptr.");
272 if (prim->GetAttr(ops::kFormat) != nullptr) {
273 format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
274 }
275 std::string format_str = "NCHW";
276 if (format == Format::NHWC) {
277 format_str = "NHWC";
278 }
279 prim->AddAttr(name, MakeValue(format_str));
280 return lite::RET_OK;
281 }
282
// Creates a new cnode in the same graph as `cnode` and registers its input edges with the graph
// manager.
//
// @param cnode      reference node; supplies the function graph and appears in error messages.
// @param primitive  primitive of the new node.
// @param inputs     inputs of the new node, not including the primitive itself.
// @param abstract   abstract assigned to the new node.
// @param name       full name (with scope) assigned to the new node.
// @return the new node, or nullptr when cnode is null, its graph or the graph's manager is
//         missing, or node creation fails.
CNodePtr PrimitiveMapper::NewCNode(const CNodePtr &cnode, const PrimitivePtr &primitive,
                                   const std::vector<AnfNodePtr> &inputs, const abstract::AbstractBasePtr &abstract,
                                   const std::string &name) const {
  // cnode is dereferenced immediately below, so reject a null reference node like every other
  // failure in this function.
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "Failed to NewCNode, cnode cannot be nullptr";
    return nullptr;
  }
  auto func_graph = cnode->func_graph();
  if (func_graph == nullptr) {
    MS_LOG(ERROR) << "Failed to NewCNode, funcGraph cannot be nullptr";
    return nullptr;
  }
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    MS_LOG(ERROR) << "Failed to NewCNode, FuncGraph manager cannot be nullptr";
    return nullptr;
  }
  auto new_node = func_graph->NewCNode(primitive, inputs);
  if (new_node == nullptr) {
    MS_LOG(ERROR) << "Failed to create node " << name << " for node " << cnode->fullname_with_scope();
    return nullptr;
  }
  new_node->set_fullname_with_scope(name);
  // inputs[i] is wired at edge index i + 1.
  // NOTE(review): index 0 is presumably the primitive's value node; confirm against FuncGraph::NewCNode.
  for (size_t i = 0; i < inputs.size(); i++) {
    manager->SetEdge(new_node, i + 1, inputs[i]);
  }
  new_node->set_abstract(abstract);
  return new_node;
}
308
// Convenience overload of NewCNode that builds a tensor abstract from a shape and a type id
// before delegating to the abstract-taking overload.
//
// @param cnode      reference node forwarded to the delegate.
// @param primitive  primitive of the new node.
// @param inputs     inputs of the new node.
// @param shape      shape of the tensor abstract.
// @param type_id    element type id of the tensor abstract.
// @param name       full name (with scope) assigned to the new node.
// @return whatever the abstract-taking overload returns.
CNodePtr PrimitiveMapper::NewCNode(const CNodePtr &cnode, const PrimitivePtr &primitive,
                                   const std::vector<AnfNodePtr> &inputs, const ShapeVector &shape, TypeId type_id,
                                   const std::string &name) const {
  const auto element_type = TypeIdToType(type_id);
  const auto tensor_shape = std::make_shared<abstract::Shape>(shape);
  auto tensor_abstract = std::make_shared<abstract::AbstractTensor>(element_type, tensor_shape);
  return NewCNode(cnode, primitive, inputs, tensor_abstract, name);
}
316 } // namespace lite
317 } // namespace mindspore
318