• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
#include "tools/optimizer/fusion/conv_pad_fusion.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "tools/common/tensor_util.h"
#include "tools/lite_exporter/fetch_content.h"
#include "ops/fusion/pad_fusion.h"
#include "ops/fusion/conv2d_fusion.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "nnacl/op_base.h"
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "src/common/utils.h"
32 
33 namespace mindspore {
34 namespace opt {
35 namespace {
// CNode input counts (input 0 is the primitive itself).
constexpr size_t kPadInputsLength{3};  // PadFusion input count; not referenced below.
constexpr size_t kConvInputIndex{1};   // Position of the data input of a convolution CNode.
constexpr size_t kConvNoBiasLen{3};    // primitive + data + weight
constexpr size_t kConvWithBiasLen{4};  // primitive + data + weight + bias
// Required length of the conv kernel_size attribute.
constexpr size_t kFilterDimsSize{2};

// Offset of the H "before" entry inside the 8-element PadFusion paddings, per layout;
// the H "after" entry and the W pair follow it consecutively.
constexpr size_t NHWCTopPadPos{2};
constexpr size_t NCHWTopPadPos{4};

// Indices into a convolution pad_list ordered as [top, bottom, left, right].
constexpr size_t kTop{0};
constexpr size_t kBottom{1};
constexpr size_t kLeft{2};
constexpr size_t kRight{3};
constexpr size_t kPadDims{4};  // Length of a convolution pad_list.

// Element count of the PadFusion paddings tensor (before/after pair for each of 4 dims).
constexpr int kPadElementNum{8};
49 
ReplaceParamsAndNodes(const FuncGraphPtr & func_graph,const CNodePtr & conv_cnode,const CNodePtr & pad_cnode,const std::string & pattern_name)50 void ReplaceParamsAndNodes(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode, const CNodePtr &pad_cnode,
51                            const std::string &pattern_name) {
52   auto paddings = pad_cnode->input(kInputIndexTwo)->cast<ParameterPtr>();
53   MS_ASSERT(paddings != nullptr);
54   MS_ASSERT(paddings->default_param() != nullptr);
55   auto pad_list = std::dynamic_pointer_cast<tensor::Tensor>(paddings->default_param());
56   MS_ASSERT(pad_list != nullptr);
57   MS_ASSERT(pad_list->ElementsNum() == kPadElementNum);
58   auto pad_data = static_cast<int32_t *>(pad_list->data_c());
59   MS_ASSERT(pad_data != nullptr);
60 
61   std::vector<int64_t> pad_list_data;
62   std::vector<int> zero_pos;
63   if (pattern_name == "PadConvPatternName") {
64     zero_pos = {0, 1, kLeft + NCHWTopPadPos, kRight + NCHWTopPadPos};
65     pad_list_data.push_back(pad_data[kTop + NHWCTopPadPos]);
66     pad_list_data.push_back(pad_data[kBottom + NHWCTopPadPos]);
67     pad_list_data.push_back(pad_data[kLeft + NHWCTopPadPos]);
68     pad_list_data.push_back(pad_data[kRight + NHWCTopPadPos]);
69   } else {
70     zero_pos = {0, 1, kTop + NHWCTopPadPos, kBottom + NHWCTopPadPos};
71     pad_list_data.push_back(pad_data[kTop + NCHWTopPadPos]);
72     pad_list_data.push_back(pad_data[kBottom + NCHWTopPadPos]);
73     pad_list_data.push_back(pad_data[kLeft + NCHWTopPadPos]);
74     pad_list_data.push_back(pad_data[kRight + NCHWTopPadPos]);
75   }
76   if (std::any_of(zero_pos.begin(), zero_pos.end(), [&pad_data](int pos) { return pad_data[pos] != 0; })) {
77     return;
78   }
79 
80   auto conv_primitive = ops::GetOperator<ops::Conv2DFusion>(conv_cnode->input(0));
81   MS_ASSERT(conv_primitive != nullptr);
82   int64_t conv_pad_mode = conv_primitive->GetAttr(ops::kPadMode) == nullptr ? 0 : conv_primitive->get_pad_mode();
83   if (conv_pad_mode == PadMode::PAD) {
84     auto pad_list_node = conv_primitive->GetAttr(ops::kPadList);
85     if (pad_list_node != nullptr) {
86       std::vector<int64_t> conv_pad_list = GetValue<std::vector<int64_t>>(pad_list_node);
87       if (conv_pad_list.size() == kPadDims) {
88         pad_list_data[kTop] += conv_pad_list[kTop];
89         pad_list_data[kBottom] += conv_pad_list[kBottom];
90         pad_list_data[kLeft] += conv_pad_list[kLeft];
91         pad_list_data[kRight] += conv_pad_list[kRight];
92       }
93     }
94   } else if (conv_pad_mode == PadMode::SAME) {
95     auto kernel_node = conv_primitive->GetAttr(ops::kKernelSize);
96     MS_ASSERT(kernel_node != nullptr);
97     std::vector<int64_t> kernel_list = GetValue<std::vector<int64_t>>(kernel_node);
98     if (kernel_list.size() != kFilterDimsSize) {
99       MS_LOG(ERROR) << "Filter Dims should be 2, Fusion failed! ,name:" << conv_cnode->fullname_with_scope();
100       return;
101     } else if (kernel_list[0] == kernel_list[1]) {
102       int64_t pad_size = std::floor(kernel_list[0] / 2);
103       for (size_t i = 0; i < pad_list_data.size(); ++i) {
104         pad_list_data[i] += pad_size;
105       }
106     } else {
107       int64_t top_pad_size = std::floor(kernel_list[0] / 2);
108       int64_t left_pad_size = std::floor(kernel_list[1] / 2);
109       pad_list_data[kTop] += top_pad_size;
110       pad_list_data[kBottom] += top_pad_size;
111       pad_list_data[kLeft] += left_pad_size;
112       pad_list_data[kRight] += left_pad_size;
113     }
114     conv_primitive->set_pad_mode(PadMode::PAD);
115   } else {
116     conv_primitive->set_pad_mode(PadMode::PAD);
117   }
118   conv_primitive->set_pad_list(pad_list_data);
119 
120   // delete padFusion
121   auto manager = func_graph->manager();
122   MS_ASSERT(manager != nullptr);
123   (void)manager->Replace(pad_cnode, pad_cnode->input(1));
124 }
125 
IsPrimitiveProper(const CNodePtr & pad_cnode)126 bool IsPrimitiveProper(const CNodePtr &pad_cnode) {
127   MS_ASSERT(pad_cnode != nullptr);
128   if (!utils::isa<Parameter>(pad_cnode->input(kInputIndexTwo))) {
129     return false;
130   }
131   auto pad_list = pad_cnode->input(kInputIndexTwo)->cast<ParameterPtr>();
132   auto tensor_param = pad_list->default_param();
133   if (tensor_param == nullptr) {
134     return false;
135   }
136   auto tensor = tensor_param->cast<tensor::TensorPtr>();
137   if (tensor == nullptr) {
138     return false;
139   }
140   if (tensor->data_type() != kNumberTypeInt32 && tensor->data_type() != kNumberTypeInt) {
141     return false;
142   }
143   if (tensor->data_c() == nullptr || tensor->ElementsNum() != kPadElementNum) {
144     return false;
145   }
146   auto prim = GetValueNode<PrimitiveCPtr>(pad_cnode->input(0));
147   MS_ASSERT(prim != nullptr);
148   auto pad_primitive = api::MakeShared<ops::PadFusion>(prim);
149   if (!prim->HasAttr(ops::kPaddingMode)) {
150     return false;
151   }
152   MS_ASSERT(pad_primitive != nullptr);
153   int64_t pad_mode = pad_primitive->get_padding_mode();
154   if (pad_mode != PaddingMode::CONSTANT) {
155     return false;
156   }
157   float pad_value = 0;
158   if (pad_cnode->size() == kInputSizeThree) {
159     auto pad_constant_node = pad_primitive->GetAttr(ops::kConstantValue);
160     if (pad_constant_node == nullptr) {
161       return false;
162     }
163     pad_value = GetValue<float>(pad_constant_node);
164   } else {
165     MS_ASSERT(pad_cnode->size() == kInputSizeFour);
166     if (pad_cnode->input(kInputIndexThree)->isa<CNode>()) {
167       return false;
168     }
169     lite::DataInfo data_info;
170     auto ret = lite::FetchConstData(pad_cnode, kInputIndexThree, converter::kFmkTypeMs, &data_info, false);
171     MS_CHECK_TRUE_RET(ret == lite::RET_OK, lite::RET_ERROR);
172     if (data_info.data_ptr_ == nullptr) {
173       return false;
174     }
175     if (data_info.data_type_ != kNumberTypeFloat32 && data_info.data_type_ != kNumberTypeFloat) {
176       return false;
177     }
178     pad_value = *static_cast<float *>(data_info.data_ptr_);
179   }
180   if (!mindspore::lite::FloatCompare(pad_value)) {
181     return false;
182   }
183 
184   return true;
185 }
186 }  // namespace
187 
DefinePadConvPattern() const188 VectorRef ConvPadFusion::DefinePadConvPattern() const {
189   auto is_pad = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimPadFusion>);
190   MS_CHECK_TRUE_RET(is_pad != nullptr, {});
191   auto is_conv = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimConv2DFusion>);
192   MS_CHECK_TRUE_RET(is_conv != nullptr, {});
193   auto is_param = std::make_shared<CondVar>(IsParamNode);
194   MS_CHECK_TRUE_RET(is_param != nullptr, {});
195   auto is_seq_var = std::make_shared<SeqVar>();
196   MS_CHECK_TRUE_RET(is_seq_var != nullptr, {});
197   return VectorRef({is_conv, is_pad, is_param, is_seq_var});
198 }
199 
DefinePadTransposeConvPattern() const200 VectorRef ConvPadFusion::DefinePadTransposeConvPattern() const {
201   auto is_pad = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimPadFusion>);
202   MS_CHECK_TRUE_RET(is_pad != nullptr, {});
203   auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
204   MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
205   auto is_param_perm = std::make_shared<CondVar>(IsParamNode);
206   MS_CHECK_TRUE_RET(is_param_perm != nullptr, {});
207   VectorRef transpose_conv_ref = VectorRef({is_transpose, is_pad, is_param_perm});
208 
209   auto is_conv = std::make_shared<CondVar>(IsConvNode);
210   MS_CHECK_TRUE_RET(is_conv != nullptr, {});
211   auto is_param_weight = std::make_shared<CondVar>(IsParamNode);
212   MS_CHECK_TRUE_RET(is_param_weight != nullptr, {});
213   auto is_seq_var = std::make_shared<SeqVar>();
214   MS_CHECK_TRUE_RET(is_seq_var != nullptr, {});
215   VectorRef trans_conv_ref = VectorRef({is_conv, transpose_conv_ref, is_param_weight, is_seq_var});
216   return trans_conv_ref;
217 }
218 
DefinePatterns() const219 std::unordered_map<std::string, VectorRef> ConvPadFusion::DefinePatterns() const {
220   std::unordered_map<std::string, VectorRef> patterns;
221   patterns["PadConvPatternName"] = DefinePadConvPattern();
222   patterns["PadTransposeConvPatternName"] = DefinePadTransposeConvPattern();
223   return patterns;
224 }
225 
Process(const std::string & pattern_name,const FuncGraphPtr & func_graph,const AnfNodePtr & node,const EquivPtr & equiv) const226 AnfNodePtr ConvPadFusion::Process(const std::string &pattern_name, const FuncGraphPtr &func_graph,
227                                   const AnfNodePtr &node, const EquivPtr &equiv) const {
228   if (func_graph == nullptr || node == nullptr) {
229     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
230     return nullptr;
231   }
232 
233   auto conv_cnode = node->cast<CNodePtr>();
234   MS_CHECK_TRUE_RET(conv_cnode != nullptr, nullptr);
235   if (IsMarkedTrainOp(conv_cnode)) {
236     return nullptr;
237   }
238   if (conv_cnode->size() != kConvWithBiasLen && conv_cnode->size() != kConvNoBiasLen) {
239     MS_LOG(INFO) << "conv node inputs error ,name:" << conv_cnode->fullname_with_scope();
240     return nullptr;
241   }
242   CNodePtr pad_cnode = nullptr;
243   if (pattern_name == "PadTransposeConvPatternName") {
244     auto transpose_cnode = conv_cnode->input(1)->cast<CNodePtr>();
245     MS_CHECK_TRUE_RET(transpose_cnode != nullptr, nullptr);
246     if (IsMarkedTrainOp(transpose_cnode)) {
247       return nullptr;
248     }
249     if (IsMultiOutputTensors(func_graph, transpose_cnode)) {
250       MS_LOG(INFO) << "transpose node is used as input by multiple cnodes, Fusion failed! ,name:"
251                    << transpose_cnode->fullname_with_scope();
252       return nullptr;
253     }
254     MS_ASSERT(transpose_cnode != nullptr);
255     pad_cnode = transpose_cnode->input(1)->cast<CNodePtr>();
256   } else {
257     pad_cnode = conv_cnode->input(1)->cast<CNodePtr>();
258     if (IsMarkedTrainOp(pad_cnode)) {
259       return nullptr;
260     }
261   }
262   MS_CHECK_TRUE_RET(pad_cnode != nullptr, nullptr);
263 
264   if (IsMultiOutputTensors(func_graph, pad_cnode)) {
265     MS_LOG(INFO) << "pad node is used as input by multiple cnodes, Fusion failed! ,name:"
266                  << pad_cnode->fullname_with_scope();
267     return nullptr;
268   }
269 
270   if (pad_cnode->size() != kInputSizeThree && pad_cnode->size() != kInputSizeFour) {
271     MS_LOG(INFO) << "pad node inputs error ,name:" << pad_cnode->fullname_with_scope();
272     return nullptr;
273   }
274 
275   if (!IsPrimitiveProper(pad_cnode)) {
276     MS_LOG(INFO) << conv_cnode->fullname_with_scope() << " does not match with previous "
277                  << pad_cnode->fullname_with_scope() << " op. Fusion failed!";
278     return nullptr;
279   }
280 
281   ReplaceParamsAndNodes(func_graph, conv_cnode, pad_cnode, pattern_name);
282   return nullptr;
283 }
284 }  // namespace opt
285 }  // namespace mindspore
286