/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/litert/runtime_pass.h"
#include "nnacl/conv_parameter.h"

namespace mindspore::lite {
#ifndef RUNTIME_PASS_CLIP
namespace {
constexpr int kMaxDepth = 2048;
}  // namespace

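/* Update a tensor's format tag; for 4D tensors also reorder the logical shape.
 * NC4HW4 uses the same NCHW-ordered logical shape as NCHW. */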
void ChangeTensorDesc(Tensor *tensor, Format specified_format) {
  if (tensor->shape().size() == DIMENSION_4D) {
    auto batch = tensor->Batch();
    auto height = tensor->Height();
    auto width = tensor->Width();
    auto channel = tensor->Channel();
    if (specified_format == NHWC) {
      tensor->set_shape({batch, height, width, channel});
    } else if (specified_format == NCHW || specified_format == NC4HW4) {
      tensor->set_shape({batch, channel, height, width});
    }
  }
  tensor->set_format(specified_format);
  return;
}

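/* Rewrite the matched subgraph
 *   conv -> transpose(NHWC->NCHW) -> c4-op -> transpose(NCHW->NHWC) -> ...
 * into
 *   conv(NC4HW4 output) -> c4-op(NHWC output) -> ...
 * deleting both transpose kernels and their now-unused tensors. */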
void Nc4hw4PassReplace(std::vector<kernel::KernelExec *> *kernels, std::vector<Tensor *> *tensors, size_t index) {
  kernel::KernelExec *conv_kernel = kernels->at(index);
  kernel::KernelExec *transpose_kernel = conv_kernel->out_kernels().front();
  kernel::KernelExec *c4_kernel = transpose_kernel->out_kernels().front();
  kernel::KernelExec *transpose2_kernel = c4_kernel->out_kernels().front();
  std::vector<kernel::KernelExec *> end_kernels = transpose2_kernel->out_kernels();

  /* tensor */
  {
    /* transpose_kernel */
    Tensor *transpose_param_tensor = transpose_kernel->in_tensors().at(1);
    (void)VectorSetNull(tensors, transpose_param_tensor);
    delete transpose_param_tensor;
    transpose_param_tensor = nullptr;

    Tensor *conv_out_tensor = conv_kernel->out_tensors().front();
    ChangeTensorDesc(conv_out_tensor, NC4HW4);
    Tensor *c4_input_tensor = c4_kernel->in_tensors().front();
    c4_kernel->set_in_tensor(conv_out_tensor, 0);
    (void)VectorSetNull(tensors, c4_input_tensor);
    delete c4_input_tensor;
    c4_input_tensor = nullptr;
  }
  {
    /* transpose2_kernel */
    Tensor *transpose_param_tensor = transpose2_kernel->in_tensors().at(1);
    (void)VectorSetNull(tensors, transpose_param_tensor);
    delete transpose_param_tensor;
    transpose_param_tensor = nullptr;

    Tensor *nhwc_tensor = c4_kernel->out_tensors().front();
    ChangeTensorDesc(nhwc_tensor, NHWC);
    for (auto end : end_kernels) {
      end->set_in_tensor(nhwc_tensor, 0);
    }
    Tensor *trans_out = transpose2_kernel->out_tensors().front();
    (void)VectorSetNull(tensors, trans_out);
    delete trans_out;
    trans_out = nullptr;
  }

  /* kernel */
  (void)VectorErase(kernels, transpose_kernel);
  delete transpose_kernel;
  transpose_kernel = nullptr;
  conv_kernel->set_out_kernels({c4_kernel});
  c4_kernel->set_in_kernels({conv_kernel});

  c4_kernel->set_out_kernels(transpose2_kernel->out_kernels());
  for (auto end : end_kernels) {
    end->set_in_kernels({c4_kernel});
  }
  (void)VectorErase(kernels, transpose2_kernel);
  delete transpose2_kernel;
  transpose2_kernel = nullptr;

  return;
}

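/* Return true when kernels[index] starts the fusable pattern:
 * an op producing NC4HW4-compatible output (excluding depthwise/group conv),
 * a transpose, an op consuming NC4HW4 input, and a second transpose,
 * all topologically ordered in the kernel list. */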
bool Nc4hw4PassMatch(const std::vector<kernel::KernelExec *> *kernels, size_t index) {
  kernel::KernelExec *start_kernel = kernels->at(index);
  if (IsContain(Nc4hw4FormatOutOpList, start_kernel->type()) == false) {
    return false;
  }
  if (start_kernel->out_kernels().size() != 1) {
    return false;
  }
  MS_CHECK_TRUE_MSG(start_kernel->op_parameter() != nullptr, false, "kernel->op_parameter() is nullptr.");
  if (reinterpret_cast<ConvParameter *>(start_kernel->op_parameter())->group_ != 1) {
    /* conv-depthwise and group-conv are not supported */
    return false;
  }

  kernel::KernelExec *transpose_nhwc2nchw_kernel = start_kernel->out_kernels().front();
  if (transpose_nhwc2nchw_kernel->type() != Nc4hw4FormatTransposeOp) {
    return false;
  }
  if (transpose_nhwc2nchw_kernel->out_kernels().size() != 1) {
    return false;
  }

  kernel::KernelExec *end_kernel = transpose_nhwc2nchw_kernel->out_kernels().front();
  if (IsContain(Nc4hw4FormatInOpList, end_kernel->type()) == false) {
    return false;
  }
  if (end_kernel->out_kernels().size() != 1) {
    return false;
  }

  kernel::KernelExec *transpose_nchw2nhwc_kernel = end_kernel->out_kernels().front();
  if (transpose_nchw2nhwc_kernel->type() != Nc4hw4FormatTransposeOp) {
    return false;
  }

  /* double check that the ops are topologically sorted in the kernel list */
  auto start_iter = find(kernels->begin(), kernels->end(), start_kernel);
  auto start_index = std::distance(kernels->begin(), start_iter);
  auto transpose_nhwc2nchw_iter = find(kernels->begin(), kernels->end(), transpose_nhwc2nchw_kernel);
  auto transpose_nhwc2nchw_index = std::distance(kernels->begin(), transpose_nhwc2nchw_iter);
  auto end_iter = find(kernels->begin(), kernels->end(), end_kernel);
  auto end_index = std::distance(kernels->begin(), end_iter);
  auto transpose_nchw2nhwc_iter = find(kernels->begin(), kernels->end(), transpose_nchw2nhwc_kernel);
  auto transpose_nchw2nhwc_index = std::distance(kernels->begin(), transpose_nchw2nhwc_iter);
  if (start_index > transpose_nhwc2nchw_index || transpose_nhwc2nchw_index > end_index ||
      end_index > transpose_nchw2nhwc_index) {
    return false;
  }

  return true;
}

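/* Runtime passes only run on CPU subgraphs built for ARM64 or AVX, and skip
 * graphs containing aware-training or post-training quantized kernels. */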
bool RuntimePassValid(kernel::SubGraphKernel *subgraph) {
  if (subgraph->desc().arch != kernel::KERNEL_ARCH::kCPU) {
    return false;
  }

#if !defined(ENABLE_ARM64) && !defined(ENABLE_AVX)
  return false;
#endif

  auto kernels = subgraph->nodes();

  for (auto kernel : kernels) {
    MS_CHECK_TRUE_MSG(kernel != nullptr, false, "kernel is nullptr.");
    if (kernel->op_parameter() != nullptr) {
      if (kernel->op_parameter()->quant_type_ == schema::QuantType_AwareTraining ||
          kernel->op_parameter()->quant_type_ == schema::QuantType_PostTraining) {
        return false;
      }
    }
  }
  return true;
}

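/* Scan the kernel list, recursing into subgraphs (bounded by kMaxDepth),
 * and apply Nc4hw4PassReplace wherever Nc4hw4PassMatch succeeds. */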
void Nc4hw4PassAct(std::vector<kernel::KernelExec *> *kernels, std::vector<Tensor *> *tensors, int i) {
  if (i > kMaxDepth) {
    MS_LOG(ERROR) << "exceed max depth " << kMaxDepth << ", i " << i;
    return;
  }
  i++;
  size_t kernel_size = kernels->size();
  size_t index = 0;
  for (; index + 3 < kernel_size; index++) {
    kernel::KernelExec *kernel = kernels->at(index);

    if (kernel->subgraph_type() != kernel::kNotSubGraph) {
      kernel::SubGraphKernel *subgraph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
      std::vector<kernel::KernelExec *> &partial_nodes = subgraph->nodes();
      Nc4hw4PassAct(&partial_nodes, tensors, i);
    }

    if (Nc4hw4PassMatch(kernels, index)) {
      Nc4hw4PassReplace(kernels, tensors, index);
      index += 1;
    }
    kernel_size = kernels->size();
  }
  return;
}

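/* Switch the tensors linking conv_op and in_op to NC4HW4 so the data flows
 * between them in C4 layout without an intermediate transpose. */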
void ConvNormC4PassActReplace(const kernel::KernelExec *conv_op, const kernel::KernelExec *in_op) {
  auto connect_tensor = conv_op->out_tensors().front();
  if (connect_tensor->shape().size() != DIMENSION_4D) {
    return;
  }
  ChangeTensorDesc(connect_tensor, NC4HW4);
  ChangeTensorDesc(in_op->in_tensors().front(), NC4HW4);
}

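/* Match conv2d-fusion -> instance-norm, optionally with an activation in
 * between, and connect the pair through an NC4HW4 tensor. */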
void ConvNormC4PassActIndex(std::vector<kernel::KernelExec *> *kernels, size_t index) {
  kernel::KernelExec *start_kernel = kernels->at(index);
  if (start_kernel->type() != ConvNormC4OpConv2DFusion) {
    return;
  }
  if (start_kernel->out_kernels().size() != 1) {
    return;
  }
  if (start_kernel->op_parameter() == nullptr) {
    return;
  }
  if (reinterpret_cast<ConvParameter *>(start_kernel->op_parameter())->group_ != 1) {
    /* conv-depthwise and group-conv are not supported */
    return;
  }

  kernel::KernelExec *after_kernel = start_kernel->out_kernels().front();
  if (after_kernel->type() == ConvNormC4OpActivation) {
    if (after_kernel->out_kernels().size() != 1) {
      return;
    }
    kernel::KernelExec *end_kernel = after_kernel->out_kernels().front();
    if (end_kernel->type() == ConvNormC4OpInstanceNorm) {
      ConvNormC4PassActReplace(start_kernel, end_kernel);
    }
    return;
  }

  if (after_kernel->type() == ConvNormC4OpInstanceNorm) {
    ConvNormC4PassActReplace(start_kernel, after_kernel);
  }

  return;
}

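/* Apply the conv-norm C4 pass at every position in the kernel list. */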
void ConvNormC4PassAct(std::vector<kernel::KernelExec *> *kernels) {
  size_t kernel_size = kernels->size();
  size_t index = 0;
  for (; index < kernel_size; index++) {
    ConvNormC4PassActIndex(kernels, index);
  }
  return;
}

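/* Delete a transpose whose output shape matches its input shape except for
 * size-1 dimensions and whose sole consumer is a reshape; its producer is
 * rewired directly to the reshape. Sets *changed when a kernel is removed. */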
STATUS DeleteRedundantTrans(std::vector<kernel::KernelExec *> *kernels, bool *changed = nullptr) {
  /* iterate by index: erasing from *kernels would invalidate range-for iterators */
  for (size_t kernel_index = 0; kernel_index < kernels->size(); kernel_index++) {
    auto *pre_kernel = kernels->at(kernel_index);
    if (pre_kernel->subgraph_type() != kernel::kNotSubGraph) {
      auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(pre_kernel);
      auto &partial = sub_graph->nodes();
      if (DeleteRedundantTrans(&partial, changed) != RET_OK) {
        MS_LOG(ERROR) << "DeleteRedundantTrans failed in subgraph.";
        return RET_ERROR;
      }
    }
    if (pre_kernel->type() != schema::PrimitiveType_Transpose) {
      continue;
    }
    if (pre_kernel->in_tensors().size() < 1 || pre_kernel->out_tensors().size() < 1) {
      MS_LOG(ERROR) << "kernel input or output is empty.";
      return RET_ERROR;
    }
    auto pre_kernel_in_tensor_shape = pre_kernel->in_tensors().at(0)->shape();
    auto pre_kernel_out_tensor_shape = pre_kernel->out_tensors().at(0)->shape();
    for (size_t i = 0; i < pre_kernel_out_tensor_shape.size(); i++) {
      if (pre_kernel_out_tensor_shape[i] == -1) {
        MS_LOG(DEBUG) << "dynamic shape, input needs resize first.";
        return RET_OK;
      }
      if (pre_kernel_out_tensor_shape[i] != pre_kernel_in_tensor_shape[i] && pre_kernel_out_tensor_shape[i] != 1) {
        MS_LOG(DEBUG) << "transpose is not redundant, do not delete.";
        return RET_OK;
      }
    }
    auto post_kernel_size = pre_kernel->out_kernels().size();
    if (post_kernel_size != 1) {
      continue;
    }
    auto post_kernel = pre_kernel->out_kernels().front();
    if (post_kernel->type() != schema::PrimitiveType_Reshape) {
      continue;
    }
    if (pre_kernel->in_kernels().size() != 1) {
      continue;
    }
    /* replace pre_kernel with post_kernel in its producer's output list */
    auto pre_in_kernel = pre_kernel->in_kernels().front();
    auto pre_output_kernels = pre_in_kernel->out_kernels();
    auto item = find(pre_output_kernels.begin(), pre_output_kernels.end(), pre_kernel);
    if (item == pre_output_kernels.end()) {
      MS_LOG(ERROR) << "kernel's out_kernels is invalid.";
      return RET_ERROR;
    }
    *item = post_kernel;
    pre_in_kernel->set_out_kernels(pre_output_kernels);

    /* replace pre_kernel with its producer in post_kernel's input list */
    auto post_in_kernels = post_kernel->in_kernels();
    item = find(post_in_kernels.begin(), post_in_kernels.end(), pre_kernel);
    if (item == post_in_kernels.end()) {
      MS_LOG(ERROR) << "kernel's in_kernels is invalid.";
      return RET_ERROR;
    }
    *item = pre_in_kernel;
    post_kernel->set_in_kernels(post_in_kernels);
    post_kernel->set_in_tensor(pre_kernel->in_tensors()[0], 0);
    (void)kernels->erase(kernels->begin() + kernel_index);
    kernel_index--;  /* the next kernel has shifted into this slot */
    if (changed != nullptr) {
      *changed = true;
    }
    delete pre_kernel;
  }
  return RET_OK;
}
#endif

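/* Entry point: run the NC4HW4, conv-norm C4 and redundant-transpose passes
 * on every valid CPU subgraph. */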
STATUS RuntimePass(std::vector<kernel::KernelExec *> *subgraphs, std::vector<Tensor *> *tensors) {
#ifndef RUNTIME_PASS_CLIP
  for (auto subgraph : *subgraphs) {
    auto sub = reinterpret_cast<kernel::SubGraphKernel *>(subgraph);
    if (RuntimePassValid(sub) == false) {
      continue;
    }

    int i = 0;
    auto &kernels = sub->nodes();
    Nc4hw4PassAct(&kernels, tensors, i);
    ConvNormC4PassAct(&kernels);
    auto status = DeleteRedundantTrans(&kernels);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "DeleteRedundantTrans failed.";
      return RET_ERROR;
    }
  }
#endif
  return RET_OK;
}

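/* Run only redundant-transpose elimination, recording on the subgraph
 * whether its kernel list changed. */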
STATUS GraphOptimizePass(std::vector<kernel::KernelExec *> *sub_graphs) {
#ifndef RUNTIME_PASS_CLIP
  for (auto subgraph : *sub_graphs) {
    auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(subgraph);
    if (RuntimePassValid(sub_graph) == false) {
      continue;
    }
    auto &kernels = sub_graph->nodes();
    bool changed = false;
    auto status = DeleteRedundantTrans(&kernels, &changed);
    if (changed) {
      sub_graph->SetGraphChanged(changed);
    }
    if (status != RET_OK) {
      MS_LOG(ERROR) << "DeleteRedundantTrans failed.";
      return RET_ERROR;
    }
  }
#endif
  return RET_OK;
}
}  // namespace mindspore::lite