• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
18 #include <map>
19 #include <vector>
20 #include "tools/converter/adapter/acl/common/utils.h"
21 #include "tools/optimizer/common/gllo_utils.h"
22 #include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
23 #include "ir/graph_utils.h"
24 #include "include/errorcode.h"
25 #include "include/registry/converter_context.h"
26 #include "ops/op_utils.h"
27 #include "ops/fusion/avg_pool_fusion.h"
28 #include "ops/fusion/max_pool_fusion.h"
29 #include "plugin/device/cpu/kernel/nnacl/op_base.h"
30 #include "src/common/log_util.h"
31 
32 namespace mindspore {
33 namespace lite {
namespace {
// Attribute length that AttrAdjust expands into a 4-element value.
constexpr size_t kCommonAttrValueNum = 2;
// Pad-mode attribute name written instead of ops::kPadMode when the source primitive is acl::kNameMaxPoolV3.
constexpr const char *kNamePaddingMode = "padding_mode";
// Boolean attribute derived from the round mode in AdjustOnnxPoolAttr.
constexpr const char *kNameCeilMode = "ceil_mode";
}  // namespace
39 
Mapper(const CNodePtr & cnode)40 STATUS PrimitiveMapper::Mapper(const CNodePtr &cnode) { return lite::RET_OK; }
41 
GetValueNodeAndPrimFromCnode(const CNodePtr & cnode,ValueNodePtr * value_node,PrimitivePtr * prim_ptr) const42 STATUS PrimitiveMapper::GetValueNodeAndPrimFromCnode(const CNodePtr &cnode, ValueNodePtr *value_node,
43                                                      PrimitivePtr *prim_ptr) const {
44   CHECK_NULL_RETURN(cnode);
45   CHECK_NULL_RETURN(value_node);
46   CHECK_NULL_RETURN(prim_ptr);
47   CHECK_NULL_RETURN(cnode->input(0));
48 
49   *value_node = cnode->input(0)->cast<ValueNodePtr>();
50   if (*value_node == nullptr) {
51     MS_LOG(ERROR) << "Value node[" << cnode->fullname_with_scope() << "] is nullptr.";
52     return lite::RET_ERROR;
53   }
54   *prim_ptr = GetValueNode<PrimitivePtr>(*value_node);
55   if (*prim_ptr == nullptr) {
56     MS_LOG(ERROR) << "Value node[" << cnode->fullname_with_scope() << "] cast to primitive failed.";
57     return lite::RET_ERROR;
58   }
59   return lite::RET_OK;
60 }
61 
AttrAdjust(const PrimitivePtr & prim,const std::string & name) const62 STATUS PrimitiveMapper::AttrAdjust(const PrimitivePtr &prim, const std::string &name) const {
63   MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "prim is nullptr.");
64   auto value_ptr = prim->GetAttr(name);
65   if (value_ptr == nullptr) {
66     MS_LOG(WARNING) << prim->name() << " has no attr " << name;
67     return lite::RET_OK;
68   }
69   if (utils::isa<ValueSequencePtr>(value_ptr)) {
70     auto val_seq_ptr = value_ptr->cast<ValueSequencePtr>();
71     CHECK_NULL_RETURN(val_seq_ptr);
72     ValuePtr first_val = nullptr;
73     if (!val_seq_ptr->value().empty()) {
74       first_val = val_seq_ptr->value().front();
75     }
76     CHECK_NULL_RETURN(first_val);
77     CHECK_NULL_RETURN(first_val->type());
78     if (first_val->type()->number_type() != kNumberTypeInt64) {
79       MS_LOG(ERROR) << "Value number type of name: " << prim->name() << " ,please check the attr name: " << name;
80       return lite::RET_ERROR;
81     }
82   } else {
83     CHECK_NULL_RETURN(value_ptr->type());
84     if (value_ptr->type()->number_type() != kNumberTypeInt64) {
85       MS_LOG(ERROR) << "Value number type of name: " << prim->name() << " ,please check the attr name: " << name;
86       return lite::RET_ERROR;
87     }
88   }
89   auto origin_value = opt::CastToInt(value_ptr);
90   if (origin_value.size() == kCommonAttrValueNum) {
91     // expand to 4
92     int64_t format = Format::NCHW;
93     if (prim->GetAttr(ops::kFormat) != nullptr) {
94       format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
95     }
96     std::vector<int64_t> new_value = {1, 1, static_cast<int64_t>(origin_value[0]),
97                                       static_cast<int64_t>(origin_value[1])};
98     if (format == Format::NHWC) {
99       std::vector<int64_t> tmp = {1, static_cast<int64_t>(origin_value[0]), static_cast<int64_t>(origin_value[1]), 1};
100       new_value.swap(tmp);
101     }
102     prim->AddAttr(name, MakeValue(new_value));
103   }
104   return lite::RET_OK;
105 }
106 
AdjustCaffePoolAttr(const std::string & src_prim_name,const PrimitivePtr & dst_prim) const107 void PrimitiveMapper::AdjustCaffePoolAttr(const std::string &src_prim_name, const PrimitivePtr &dst_prim) const {
108   int64_t mode = src_prim_name == ops::kNameAvgPoolFusion ? 1 : 0;
109   dst_prim->AddAttr(ops::kMode, MakeValue(mode));
110 
111   auto run_mode_val = dst_prim->GetAttr(ops::kRoundMode);
112   if (run_mode_val == nullptr) {
113     MS_LOG(INFO) << "There is no attr run mode";
114     return;
115   }
116   auto run_mode = GetValue<int64_t>(run_mode_val);
117   int64_t run_mode_ge = run_mode == RoundMode::FLOOR ? 1 : 0;
118   dst_prim->set_attr(ops::kRoundMode, MakeValue(run_mode_ge));
119 }
120 
AdjustOnnxPoolAttr(const std::string & src_prim_name,const PrimitivePtr & dst_prim) const121 void PrimitiveMapper::AdjustOnnxPoolAttr(const std::string &src_prim_name, const PrimitivePtr &dst_prim) const {
122   auto pad_mode_val = dst_prim->GetAttr(ops::kPadMode);
123   static std::map<int64_t, std::string> kPadModToStrMap = {
124     {PadMode::PAD, "CALCULATED"},
125     {PadMode::SAME, "SAME"},
126     {PadMode::VALID, "VALID"},
127   };
128   if (pad_mode_val) {
129     auto pad_mode = GetValue<int64_t>(pad_mode_val);
130     std::string padding_mode = "CALCULATED";
131     if (kPadModToStrMap.find(pad_mode) != kPadModToStrMap.end()) {
132       padding_mode = kPadModToStrMap[pad_mode];
133     }
134     if ((src_prim_name == ops::kNameMaxPool || dst_prim->name() == ops::kNameAvgPool) && padding_mode == "CALCULATED") {
135       padding_mode = "VALID";
136     }
137     std::string pad_mode_name = src_prim_name == acl::kNameMaxPoolV3 ? kNamePaddingMode : ops::kPadMode;
138     dst_prim->AddAttr(pad_mode_name, MakeValue(padding_mode));
139   } else {
140     MS_LOG(INFO) << "There is no attr pad mode";
141   }
142   auto run_mode_val = dst_prim->GetAttr(ops::kRoundMode);
143   if (run_mode_val) {
144     int64_t run_mode = GetValue<int64_t>(run_mode_val);
145     bool ceil_mode = run_mode == RoundMode::CEIL;
146     dst_prim->AddAttr(kNameCeilMode, MakeValue(ceil_mode));
147   } else {
148     MS_LOG(INFO) << "There is no attr run mode";
149   }
150 }
151 
AdjustPoolAttr(int fmk_type,const std::string & src_prim_name,const PrimitivePtr & dst_prim) const152 STATUS PrimitiveMapper::AdjustPoolAttr(int fmk_type, const std::string &src_prim_name,
153                                        const PrimitivePtr &dst_prim) const {
154   if (fmk_type == converter::kFmkTypeCaffe) {
155     AdjustCaffePoolAttr(src_prim_name, dst_prim);
156     return lite::RET_OK;
157   } else if (fmk_type == converter::kFmkTypeOnnx) {
158     AdjustOnnxPoolAttr(src_prim_name, dst_prim);
159   }
160   // adjust common attr
161   MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr.");
162   auto status = AttrAdjust(dst_prim, ops::kKernelSize);
163   if (status != lite::RET_OK) {
164     MS_LOG(ERROR) << "Adjust kernel size failed.";
165     return status;
166   }
167   status = AttrAdjust(dst_prim, ops::kStrides);
168   if (status != lite::RET_OK) {
169     MS_LOG(ERROR) << "adjust strides failed.";
170     return status;
171   }
172   return lite::RET_OK;
173 }
174 
MoveAttrMap(const CNodePtr & cnode,const PrimitivePtr & dst_prim) const175 STATUS PrimitiveMapper::MoveAttrMap(const CNodePtr &cnode, const PrimitivePtr &dst_prim) const {
176   ValueNodePtr value_node = nullptr;
177   PrimitivePtr src_prim = nullptr;
178   if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
179     MS_LOG(ERROR) << "Get primitive from cnode failed.";
180     return lite::RET_ERROR;
181   }
182   MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr.");
183   dst_prim->SetAttrs(src_prim->attrs());
184   value_node->set_value(dst_prim);
185   return lite::RET_OK;
186 }
187 
AddFloatAttrToInput(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const PrimitivePtr & dst_prim,const std::string & attr_name,bool empty_shape) const188 STATUS PrimitiveMapper::AddFloatAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
189                                             const PrimitivePtr &dst_prim, const std::string &attr_name,
190                                             bool empty_shape) const {
191   MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr.");
192   auto attr_val = dst_prim->GetAttr(attr_name);
193   if (attr_val == nullptr) {
194     MS_LOG(INFO) << "There is no attr: " << attr_name;
195     return lite::RET_OK;
196   }
197   auto param_name = cnode->fullname_with_scope() + "_" + attr_name;
198   auto inputs = cnode->inputs();
199   auto value_data = GetValue<float>(attr_val);
200   auto param_node = opt::BuildFloatValueParameterNode(func_graph, value_data, param_name, empty_shape);
201 
202   MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "param_node is nullptr.");
203   param_node->set_debug_info(std::make_shared<NodeDebugInfo>(param_name));
204   inputs.push_back(param_node);
205   cnode->set_inputs(inputs);
206   return lite::RET_OK;
207 }
208 
AddIntVecAttrToInput(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const PrimitivePtr & dst_prim,const std::string & attr_name) const209 STATUS PrimitiveMapper::AddIntVecAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
210                                              const PrimitivePtr &dst_prim, const std::string &attr_name) const {
211   MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr.");
212   auto attr_val = dst_prim->GetAttr(attr_name);
213   if (attr_val == nullptr) {
214     MS_LOG(INFO) << "There is no attr: " << attr_name;
215     return lite::RET_OK;
216   }
217   auto param_name = cnode->fullname_with_scope() + "_" + attr_name;
218   auto inputs = cnode->inputs();
219   auto value_data = opt::CastToVec2DInt(attr_val);
220   auto param_node = opt::BuildIntVec2DParameterNode(func_graph, value_data, param_name);
221 
222   MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "param_node is nullptr.");
223   param_node->set_debug_info(std::make_shared<NodeDebugInfo>(param_name));
224   inputs.push_back(param_node);
225   cnode->set_inputs(inputs);
226   return lite::RET_OK;
227 }
228 
AddIntAttrToInput(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const PrimitivePtr & dst_prim,const std::string & attr_name,bool empty_shape) const229 STATUS PrimitiveMapper::AddIntAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
230                                           const PrimitivePtr &dst_prim, const std::string &attr_name,
231                                           bool empty_shape) const {
232   MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr.");
233   auto attr_val = dst_prim->GetAttr(attr_name);
234   if (attr_val == nullptr) {
235     MS_LOG(INFO) << "There is no attr: " << attr_name;
236     return lite::RET_OK;
237   }
238   auto param_name = cnode->fullname_with_scope() + "_" + attr_name;
239   auto inputs = cnode->inputs();
240   auto value_data = opt::CastToInt(attr_val);
241   if (value_data.size() < 1) {
242     MS_LOG(ERROR) << "Invalid size: " << value_data.size();
243     return lite::RET_ERROR;
244   }
245   auto param_node = opt::BuildIntValueParameterNode(func_graph, value_data[0], param_name, empty_shape);
246 
247   MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "param_node is nullptr.");
248   param_node->set_debug_info(std::make_shared<NodeDebugInfo>(param_name));
249   inputs.push_back(param_node);
250   cnode->set_inputs(inputs);
251   return lite::RET_OK;
252 }
253 
AddAttrForDynInputPrimitive(const CNodePtr & cnode) const254 STATUS PrimitiveMapper::AddAttrForDynInputPrimitive(const CNodePtr &cnode) const {
255   CHECK_NULL_RETURN(cnode);
256   CHECK_NULL_RETURN(cnode->input(0));
257   auto value_node = cnode->input(0)->cast<ValueNodePtr>();
258   CHECK_NULL_RETURN(value_node);
259   auto prim = GetValueNode<PrimitivePtr>(value_node);
260   CHECK_NULL_RETURN(prim);
261   // add attr input num for dynamic input op
262   int64_t num = static_cast<int64_t>(cnode->size());
263   if (num > 1) {
264     prim->AddAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{num - 1, -1}));
265   }
266   return lite::RET_OK;
267 }
268 
AdjustAttrFormat(const PrimitivePtr & prim,const std::string & name) const269 STATUS PrimitiveMapper::AdjustAttrFormat(const PrimitivePtr &prim, const std::string &name) const {
270   int64_t format = Format::NCHW;
271   MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "prim is nullptr.");
272   if (prim->GetAttr(ops::kFormat) != nullptr) {
273     format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
274   }
275   std::string format_str = "NCHW";
276   if (format == Format::NHWC) {
277     format_str = "NHWC";
278   }
279   prim->AddAttr(name, MakeValue(format_str));
280   return lite::RET_OK;
281 }
282 
// Creates a CNode for `primitive` with `inputs` in `cnode`'s func graph. The new node is
// named `name` and gets `abstract` as its abstract.
// Each input edge is also registered through the graph manager, so user tracking sees the
// new node. Returns nullptr, with an error log, if `cnode`/`primitive` is null, the graph or
// its manager is missing, or node creation fails.
CNodePtr PrimitiveMapper::NewCNode(const CNodePtr &cnode, const PrimitivePtr &primitive,
                                   const std::vector<AnfNodePtr> &inputs, const abstract::AbstractBasePtr &abstract,
                                   const std::string &name) const {
  if (cnode == nullptr || primitive == nullptr) {
    MS_LOG(ERROR) << "Failed to NewCNode, cnode and primitive cannot be nullptr";
    return nullptr;
  }
  auto func_graph = cnode->func_graph();
  if (func_graph == nullptr) {
    MS_LOG(ERROR) << "Failed to NewCNode, funcGraph cannot be nullptr";
    return nullptr;
  }
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    MS_LOG(ERROR) << "Failed to NewCNode, FuncGraph manager cannot be nullptr";
    return nullptr;
  }
  auto new_node = func_graph->NewCNode(primitive, inputs);
  if (new_node == nullptr) {
    MS_LOG(ERROR) << "Failed to create node " << name << " for node " << cnode->fullname_with_scope();
    return nullptr;
  }
  new_node->set_fullname_with_scope(name);
  // Input 0 of the new node is the primitive, so data input i sits at edge index i + 1.
  for (size_t i = 0; i < inputs.size(); i++) {
    manager->SetEdge(new_node, i + 1, inputs[i]);
  }
  new_node->set_abstract(abstract);
  return new_node;
}
308 
// Convenience overload: builds an AbstractTensor from `shape` and `type_id`, then delegates
// to the abstract-taking NewCNode overload.
CNodePtr PrimitiveMapper::NewCNode(const CNodePtr &cnode, const PrimitivePtr &primitive,
                                   const std::vector<AnfNodePtr> &inputs, const ShapeVector &shape, TypeId type_id,
                                   const std::string &name) const {
  auto tensor_shape = std::make_shared<abstract::Shape>(shape);
  auto tensor_type = TypeIdToType(type_id);
  auto tensor_abstract = std::make_shared<abstract::AbstractTensor>(tensor_type, tensor_shape);
  return NewCNode(cnode, primitive, inputs, tensor_abstract, name);
}
316 }  // namespace lite
317 }  // namespace mindspore
318