/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/litert/runtime_pass.h"
#include "nnacl/conv_parameter.h"

namespace mindspore::lite {
#ifndef RUNTIME_PASS_CLIP
namespace {
constexpr int kMaxDepth = 2048;
}  // namespace

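/* Reorder a 4D tensor's shape to match the requested format (NHWC vs. NCHW/NC4HW4) and update its format flag. */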
void ChangeTensorDesc(Tensor *tensor, Format specified_format) {
  if (tensor->shape().size() == DIMENSION_4D) {
    auto batch = tensor->Batch();
    auto height = tensor->Height();
    auto width = tensor->Width();
    auto channel = tensor->Channel();
    if (specified_format == NHWC) {
      tensor->set_shape({batch, height, width, channel});
    }
    if (specified_format == NCHW || specified_format == NC4HW4) {
      tensor->set_shape({batch, channel, height, width});
    }
  }
  tensor->set_format(specified_format);
  return;
}

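/* Rewire the matched Conv -> Transpose -> C4-format op -> Transpose chain so the convolution output feeds the
 * C4 kernel directly in NC4HW4 layout, then delete both transpose kernels and their now-unused tensors. */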
void Nc4hw4PassReplace(std::vector<kernel::KernelExec *> *kernels, std::vector<Tensor *> *tensors, size_t index) {
  kernel::KernelExec *conv_kernel = kernels->at(index);
  kernel::KernelExec *transpose_kernel = conv_kernel->out_kernels().front();
  kernel::KernelExec *c4_kernel = transpose_kernel->out_kernels().front();
  kernel::KernelExec *transpose2_kernel = c4_kernel->out_kernels().front();
  std::vector<kernel::KernelExec *> end_kernels = transpose2_kernel->out_kernels();

  /* tensor */
  {
    /* transpose_kernel */
    Tensor *transpose_param_tensor = transpose_kernel->in_tensors().at(1);
    (void)VectorSetNull(tensors, transpose_param_tensor);
    delete transpose_param_tensor;
    transpose_param_tensor = nullptr;

    Tensor *conv_out_tensor = conv_kernel->out_tensors().front();
    ChangeTensorDesc(conv_out_tensor, NC4HW4);
    Tensor *c4_input_tensor = c4_kernel->in_tensors().front();
    c4_kernel->set_in_tensor(conv_out_tensor, 0);
    (void)VectorSetNull(tensors, c4_input_tensor);
    delete c4_input_tensor;
    c4_input_tensor = nullptr;
  }
  {
    /* transpose2_kernel */
    Tensor *transpose_param_tensor = transpose2_kernel->in_tensors().at(1);
    (void)VectorSetNull(tensors, transpose_param_tensor);
    delete transpose_param_tensor;
    transpose_param_tensor = nullptr;

    Tensor *nhwc_tensor = c4_kernel->out_tensors().front();
    ChangeTensorDesc(nhwc_tensor, NHWC);
    for (auto end : end_kernels) {
      end->set_in_tensor(nhwc_tensor, 0);
    }
    Tensor *trans_out = transpose2_kernel->out_tensors().front();
    (void)VectorSetNull(tensors, trans_out);
    delete trans_out;
    trans_out = nullptr;
  }

  /* kernel */
  (void)VectorErase(kernels, transpose_kernel);
  delete transpose_kernel;
  transpose_kernel = nullptr;
  conv_kernel->set_out_kernels({c4_kernel});
  c4_kernel->set_in_kernels({conv_kernel});

  c4_kernel->set_out_kernels(transpose2_kernel->out_kernels());
  for (auto end : end_kernels) {
    end->set_in_kernels({c4_kernel});
  }
  (void)VectorErase(kernels, transpose2_kernel);
  delete transpose2_kernel;
  transpose2_kernel = nullptr;

  return;
}

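/* Return true when kernels[index] starts a non-grouped NC4HW4-output op -> Transpose(NHWC2NCHW) ->
 * NC4HW4-input op -> Transpose(NCHW2NHWC) chain whose kernels appear in topological order in the list. */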
bool Nc4hw4PassMatch(const std::vector<kernel::KernelExec *> *kernels, size_t index) {
  kernel::KernelExec *start_kernel = kernels->at(index);
  if (IsContain(Nc4hw4FormatOutOpList, start_kernel->type()) == false) {
    return false;
  }
  if (start_kernel->out_kernels().size() != 1) {
    return false;
  }
  MS_CHECK_TRUE_MSG(start_kernel->op_parameter() != nullptr, false, "kernel->op_parameter() is nullptr.");
  if (reinterpret_cast<ConvParameter *>(start_kernel->op_parameter())->group_ != 1) {
    /* skip conv-depthwise and group-conv */
    return false;
  }

  kernel::KernelExec *transpose_nhwc2nchw_kernel = start_kernel->out_kernels().front();
  if (transpose_nhwc2nchw_kernel->type() != Nc4hw4FormatTransposeOp) {
    return false;
  }
  if (transpose_nhwc2nchw_kernel->out_kernels().size() != 1) {
    return false;
  }

  kernel::KernelExec *end_kernel = transpose_nhwc2nchw_kernel->out_kernels().front();
  if (IsContain(Nc4hw4FormatInOpList, end_kernel->type()) == false) {
    return false;
  }
  if (end_kernel->out_kernels().size() != 1) {
    return false;
  }

  kernel::KernelExec *transpose_nchw2nhwc_kernel = end_kernel->out_kernels().front();
  if (transpose_nchw2nhwc_kernel->type() != Nc4hw4FormatTransposeOp) {
    return false;
  }

  /* double-check that the ops are topologically sorted in the kernel list */
  auto start_iter = find(kernels->begin(), kernels->end(), start_kernel);
  auto start_index = std::distance(kernels->begin(), start_iter);
  auto transpose_nhwc2nchw_iter = find(kernels->begin(), kernels->end(), transpose_nhwc2nchw_kernel);
  auto transpose_nhwc2nchw_index = std::distance(kernels->begin(), transpose_nhwc2nchw_iter);
  auto end_iter = find(kernels->begin(), kernels->end(), end_kernel);
  auto end_index = std::distance(kernels->begin(), end_iter);
  auto transpose_nchw2nhwc_iter = find(kernels->begin(), kernels->end(), transpose_nchw2nhwc_kernel);
  auto transpose_nchw2nhwc_index = std::distance(kernels->begin(), transpose_nchw2nhwc_iter);
  if (start_index > transpose_nhwc2nchw_index || transpose_nhwc2nchw_index > end_index ||
      end_index > transpose_nchw2nhwc_index) {
    return false;
  }

  return true;
}

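/* Runtime passes only apply to CPU subgraphs built with ARM64/AVX whose kernels carry no
 * aware-training or post-training quantization. */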
bool RuntimePassValid(kernel::SubGraphKernel *subgraph) {
  if (subgraph->desc().arch != kernel::KERNEL_ARCH::kCPU) {
    return false;
  }

#if !defined(ENABLE_ARM64) && !defined(ENABLE_AVX)
  return false;
#endif

  auto kernels = subgraph->nodes();

  for (auto kernel : kernels) {
    MS_CHECK_TRUE_MSG(kernel != nullptr, false, "kernel is nullptr.");
    if (kernel->op_parameter() != nullptr) {
      if (kernel->op_parameter()->quant_type_ == schema::QuantType_AwareTraining ||
          kernel->op_parameter()->quant_type_ == schema::QuantType_PostTraining) {
        return false;
      }
    }
  }
  return true;
}

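/* Walk the kernel list, recursing into subgraphs (bounded by kMaxDepth), and replace every matched NC4HW4 chain. */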
void Nc4hw4PassAct(std::vector<kernel::KernelExec *> *kernels, std::vector<Tensor *> *tensors, int i) {
  if (i > kMaxDepth) {
    MS_LOG(ERROR) << "exceeded max recursion depth " << kMaxDepth << ", i = " << i;
    return;
  }
  i++;
  size_t kernel_size = kernels->size();
  size_t index = 0;
  for (; index + 3 < kernel_size; index++) {
    kernel::KernelExec *kernel = kernels->at(index);

    if (kernel->subgraph_type() != kernel::kNotSubGraph) {
      kernel::SubGraphKernel *subgraph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
      std::vector<kernel::KernelExec *> &partial_nodes = subgraph->nodes();
      Nc4hw4PassAct(&partial_nodes, tensors, i);
    }

    if (Nc4hw4PassMatch(kernels, index)) {
      Nc4hw4PassReplace(kernels, tensors, index);
      index += 1;
    }
    kernel_size = kernels->size();
  }
  return;
}

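/* Switch the convolution's output tensor and the downstream op's input tensor to NC4HW4 (4D shapes only). */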
void ConvNormC4PassActReplace(const kernel::KernelExec *conv_op, const kernel::KernelExec *in_op) {
  auto connect_tensor = conv_op->out_tensors().front();
  if (connect_tensor->shape().size() != DIMENSION_4D) {
    return;
  }
  ChangeTensorDesc(connect_tensor, NC4HW4);
  ChangeTensorDesc(in_op->in_tensors().front(), NC4HW4);
}

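/* If kernels[index] is a non-grouped Conv2DFusion followed, optionally through an Activation, by an
 * InstanceNorm, switch the connecting tensors to NC4HW4. */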
void ConvNormC4PassActIndex(std::vector<kernel::KernelExec *> *kernels, size_t index) {
  kernel::KernelExec *start_kernel = kernels->at(index);
  if (start_kernel->type() != ConvNormC4OpConv2DFusion) {
    return;
  }
  if (start_kernel->out_kernels().size() != 1) {
    return;
  }
  if (start_kernel->op_parameter() == nullptr) {
    return;
  }
  if (reinterpret_cast<ConvParameter *>(start_kernel->op_parameter())->group_ != 1) {
    /* skip conv-depthwise and group-conv */
    return;
  }

  kernel::KernelExec *after_kernel = start_kernel->out_kernels().front();
  if (after_kernel->type() == ConvNormC4OpActivation) {
    if (after_kernel->out_kernels().size() != 1) {
      return;
    }
    kernel::KernelExec *end_kernel = after_kernel->out_kernels().front();
    if (end_kernel->type() == ConvNormC4OpInstanceNorm) {
      ConvNormC4PassActReplace(start_kernel, end_kernel);
      return;
    }
    return;
  }

  if (after_kernel->type() == ConvNormC4OpInstanceNorm) {
    ConvNormC4PassActReplace(start_kernel, after_kernel);
    return;
  }

  return;
}

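/* Apply the Conv(+Activation)+InstanceNorm NC4HW4 pass to every kernel in the list. */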
void ConvNormC4PassAct(std::vector<kernel::KernelExec *> *kernels) {
  size_t kernel_size = kernels->size();
  size_t index = 0;
  for (; index < kernel_size; index++) {
    ConvNormC4PassActIndex(kernels, index);
  }
  return;
}

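/* Delete a Transpose kernel that feeds a single Reshape when every output dimension either equals the
 * corresponding input dimension or is 1, rewiring its producer directly to the Reshape. Recurses into
 * subgraphs and sets *changed when a kernel is removed. */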
STATUS DeleteRedundantTrans(std::vector<kernel::KernelExec *> *kernels, bool *changed = nullptr) {
  /* iterate by index: kernels may be erased from the vector inside the loop, which would invalidate
   * the iterators of a range-based for loop */
  for (size_t i = 0; i < kernels->size(); i++) {
    auto *pre_kernel = kernels->at(i);
    if (pre_kernel->subgraph_type() != kernel::kNotSubGraph) {
      auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(pre_kernel);
      auto &partial = sub_graph->nodes();
      if (DeleteRedundantTrans(&partial, changed) != RET_OK) {
        MS_LOG(ERROR) << "DeleteRedundantTrans failed in subgraph.";
        return RET_ERROR;
      }
    }
    if (pre_kernel->type() != schema::PrimitiveType_Transpose) {
      continue;
    }
    if (pre_kernel->in_tensors().size() < 1 || pre_kernel->out_tensors().size() < 1) {
      MS_LOG(ERROR) << "kernel input or output is empty.";
      return RET_ERROR;
    }
    auto pre_kernel_in_tensor_shape = pre_kernel->in_tensors().at(0)->shape();
    auto pre_kernel_out_tensor_shape = pre_kernel->out_tensors().at(0)->shape();
    for (size_t j = 0; j < pre_kernel_out_tensor_shape.size(); j++) {
      if (pre_kernel_out_tensor_shape[j] == -1) {
        MS_LOG(DEBUG) << "input needs resize, skip deleting transpose.";
        return RET_OK;
      }
      if (pre_kernel_out_tensor_shape[j] != pre_kernel_in_tensor_shape[j] && pre_kernel_out_tensor_shape[j] != 1) {
        MS_LOG(DEBUG) << "transpose is not redundant, do not delete.";
        return RET_OK;
      }
    }
    auto post_kernel_size = pre_kernel->out_kernels().size();
    if (post_kernel_size != 1) {
      continue;
    }
    auto post_kernel = pre_kernel->out_kernels().front();
    if (post_kernel->type() != schema::PrimitiveType_Reshape) {
      continue;
    }
    if (pre_kernel->in_kernels().size() != 1) {
      continue;
    }
    auto pre_in_kernel = pre_kernel->in_kernels().front();
    auto pre_output_kernels = pre_in_kernel->out_kernels();
    auto item = find(pre_output_kernels.begin(), pre_output_kernels.end(), pre_kernel);
    if (item == pre_output_kernels.end()) {
      MS_LOG(ERROR) << "kernel's out_kernels is invalid.";
      return RET_ERROR;
    }
    *item = post_kernel;
    pre_in_kernel->set_out_kernels(pre_output_kernels);

    auto post_in_kernels = post_kernel->in_kernels();
    item = find(post_in_kernels.begin(), post_in_kernels.end(), pre_kernel);
    if (item == post_in_kernels.end()) {
      MS_LOG(ERROR) << "kernel's in_kernels is invalid.";
      return RET_ERROR;
    }
    *item = pre_in_kernel;
    post_kernel->set_in_kernels(post_in_kernels);
    post_kernel->set_in_tensor(pre_kernel->in_tensors()[0], 0);
    kernels->erase(find(kernels->begin(), kernels->end(), pre_kernel));
    if (changed != nullptr) {
      *changed = true;
    }
    delete pre_kernel;
  }
  return RET_OK;
}
#endif

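/* Entry point for the runtime passes: on every valid CPU subgraph, run the NC4HW4 pass, the Conv+Norm C4 pass,
 * and the redundant-transpose pass. */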
STATUS RuntimePass(std::vector<kernel::KernelExec *> *subgraphs, std::vector<Tensor *> *tensors) {
#ifndef RUNTIME_PASS_CLIP
  for (auto subgraph : *subgraphs) {
    auto sub = reinterpret_cast<kernel::SubGraphKernel *>(subgraph);
    if (RuntimePassValid(sub) == false) {
      continue;
    }

    int i = 0;
    auto &kernels = sub->nodes();
    Nc4hw4PassAct(&kernels, tensors, i);
    ConvNormC4PassAct(&kernels);
    auto status = DeleteRedundantTrans(&kernels);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "DeleteRedundantTrans failed.";
      return RET_ERROR;
    }
  }
#endif
  return RET_OK;
}

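/* Run only the redundant-transpose pass on every valid CPU subgraph and record, via SetGraphChanged,
 * whether the graph was modified. */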
STATUS GraphOptimizePass(std::vector<kernel::KernelExec *> *sub_graphs) {
#ifndef RUNTIME_PASS_CLIP
  for (auto subgraph : *sub_graphs) {
    auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(subgraph);
    if (RuntimePassValid(sub_graph) == false) {
      continue;
    }
    auto &kernels = sub_graph->nodes();
    bool changed = false;
    auto status = DeleteRedundantTrans(&kernels, &changed);
    if (changed) {
      sub_graph->SetGraphChanged(changed);
    }
    if (status != RET_OK) {
      MS_LOG(ERROR) << "DeleteRedundantTrans failed.";
      return RET_ERROR;
    }
  }
#endif
  return RET_OK;
}
}  // namespace mindspore::lite