• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/runtime/runtime_pass.h"
18 #include "nnacl/conv_parameter.h"
19 
namespace {
// Maximum recursion depth for Nc4hw4PassAct when descending into nested
// subgraph kernels; guards against runaway recursion on malformed graphs.
// `constexpr` already implies `const`, so the redundant qualifier is dropped.
constexpr int kMaxDepth = 2048;
}  // namespace
23 
24 namespace mindspore::lite {
Nc4hw4PassReplace(std::vector<kernel::LiteKernel * > * kernels,std::vector<Tensor * > * tensors,size_t index)25 void Nc4hw4PassReplace(std::vector<kernel::LiteKernel *> *kernels, std::vector<Tensor *> *tensors, size_t index) {
26   kernel::LiteKernel *conv_kernel = kernels->at(index);
27   kernel::LiteKernel *transpose_kernel = conv_kernel->out_kernels().front();
28   kernel::LiteKernel *c4_kernel = transpose_kernel->out_kernels().front();
29   kernel::LiteKernel *transpose2_kernel = c4_kernel->out_kernels().front();
30   std::vector<kernel::LiteKernel *> end_kernels = transpose2_kernel->out_kernels();
31 
32   /* tensor */
33   {
34     /* transpose_kernel */
35     Tensor *transpose_param_tensor = transpose_kernel->in_tensors().at(1);
36     VectorSetNull(tensors, transpose_param_tensor);
37     delete transpose_param_tensor;
38     transpose_param_tensor = nullptr;
39 
40     Tensor *conv_out_tensor = conv_kernel->out_tensors().front();
41     conv_out_tensor->set_format(NC4HW4);
42     Tensor *c4_input_tensor = c4_kernel->in_tensors().front();
43     c4_kernel->set_in_tensor(conv_out_tensor, 0);
44     VectorSetNull(tensors, c4_input_tensor);
45     delete c4_input_tensor;
46     c4_input_tensor = nullptr;
47   }
48   {
49     /* transpose2_kernel */
50     Tensor *transpose_param_tensor = transpose2_kernel->in_tensors().at(1);
51     VectorSetNull(tensors, transpose_param_tensor);
52     delete transpose_param_tensor;
53     transpose_param_tensor = nullptr;
54 
55     Tensor *nwhc_tensor = c4_kernel->out_tensors().front();
56     std::vector<int> nhwc_shape = {nwhc_tensor->Batch(), nwhc_tensor->Height(), nwhc_tensor->Width(),
57                                    nwhc_tensor->Channel()};
58     nwhc_tensor->set_format(NHWC);
59     nwhc_tensor->set_shape(nhwc_shape);
60     for (auto end : end_kernels) {
61       end->set_in_tensor(nwhc_tensor, 0);
62     }
63     Tensor *trans_out = transpose2_kernel->out_tensors().front();
64     VectorSetNull(tensors, trans_out);
65     delete trans_out;
66     trans_out = nullptr;
67   }
68 
69   /* kernel */
70   VectorErase(kernels, transpose_kernel);
71   delete transpose_kernel;
72   transpose_kernel = nullptr;
73   conv_kernel->set_out_kernels({c4_kernel});
74   c4_kernel->set_in_kernels({conv_kernel});
75 
76   c4_kernel->set_out_kernels(transpose2_kernel->out_kernels());
77   for (auto end : end_kernels) {
78     end->set_in_kernels({c4_kernel});
79   }
80   VectorErase(kernels, transpose2_kernel);
81   delete transpose2_kernel;
82   transpose2_kernel = nullptr;
83 
84   return;
85 }
86 
Nc4hw4PassMatch(std::vector<kernel::LiteKernel * > * kernels,size_t index)87 bool Nc4hw4PassMatch(std::vector<kernel::LiteKernel *> *kernels, size_t index) {
88   kernel::LiteKernel *start_kernel = kernels->at(index);
89   if (IsContain(Nc4hw4FormatOutOpList, start_kernel->type()) == false) {
90     return false;
91   }
92   if (start_kernel->out_kernels().size() != 1) {
93     return false;
94   }
95   if (reinterpret_cast<ConvParameter *>(start_kernel->op_parameter())->group_ != 1) {
96     /* conv-depthwise and group-conv */
97     return false;
98   }
99 
100   kernel::LiteKernel *traspose_nhwc2nchw_kernel = start_kernel->out_kernels().front();
101   if (traspose_nhwc2nchw_kernel->type() != Nc4hw4FormatTransposeOp) {
102     return false;
103   }
104   if (traspose_nhwc2nchw_kernel->out_kernels().size() != 1) {
105     return false;
106   }
107 
108   kernel::LiteKernel *end_kernel = traspose_nhwc2nchw_kernel->out_kernels().front();
109   if (IsContain(Nc4hw4FormatInOpList, end_kernel->type()) == false) {
110     return false;
111   }
112   if (end_kernel->out_kernels().size() != 1) {
113     return false;
114   }
115 
116   kernel::LiteKernel *transpose_nchw2nhwc_kernel = end_kernel->out_kernels().front();
117   if (transpose_nchw2nhwc_kernel->type() != Nc4hw4FormatTransposeOp) {
118     return false;
119   }
120 
121   /* double check ops topological sorted in kernel-list */
122   auto start_iter = find(kernels->begin(), kernels->end(), start_kernel);
123   auto start_index = std::distance(kernels->begin(), start_iter);
124   auto traspose_nhwc2nchw_iter = find(kernels->begin(), kernels->end(), traspose_nhwc2nchw_kernel);
125   auto traspose_nhwc2nchw_index = std::distance(kernels->begin(), traspose_nhwc2nchw_iter);
126   auto end_iter = find(kernels->begin(), kernels->end(), end_kernel);
127   auto end_index = std::distance(kernels->begin(), end_iter);
128   auto transpose_nchw2nhwc_iter = find(kernels->begin(), kernels->end(), transpose_nchw2nhwc_kernel);
129   auto transpose_nchw2nhwc_index = std::distance(kernels->begin(), transpose_nchw2nhwc_iter);
130   if (start_index > traspose_nhwc2nchw_index || traspose_nhwc2nchw_index > end_index ||
131       end_index > transpose_nchw2nhwc_index) {
132     return false;
133   }
134 
135   return true;
136 }
137 
RuntimePassValid(kernel::SubGraphKernel * subgraph)138 bool RuntimePassValid(kernel::SubGraphKernel *subgraph) {
139   if (subgraph->desc().arch != kernel::KERNEL_ARCH::kCPU) {
140     return false;
141   }
142 
143   auto kernels = subgraph->nodes();
144 
145   for (auto kernel : kernels) {
146     if (kernel->op_parameter() != nullptr) {
147       if (kernel->op_parameter()->quant_type_ == schema::QuantType_AwareTraining ||
148           kernel->op_parameter()->quant_type_ == schema::QuantType_PostTraining) {
149         return false;
150       }
151     }
152   }
153   return true;
154 }
155 
Nc4hw4PassAct(std::vector<kernel::LiteKernel * > * kernels,std::vector<Tensor * > * tensors,int i)156 void Nc4hw4PassAct(std::vector<kernel::LiteKernel *> *kernels, std::vector<Tensor *> *tensors, int i) {
157   if (i > kMaxDepth) {
158     MS_LOG(ERROR) << "exceed max depth 2048, i " << i;
159     return;
160   }
161   i++;
162   size_t kernel_size = kernels->size();
163   size_t index = 0;
164   for (; index + 3 < kernel_size; index++) {
165     kernel::LiteKernel *kernel = kernels->at(index);
166 
167     if (kernel->subgraph_type() != kernel::kNotSubGraph) {
168       kernel::SubGraphKernel *subgraph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
169       std::vector<kernel::LiteKernel *> &particial_nodes = subgraph->nodes();
170       Nc4hw4PassAct(&particial_nodes, tensors, i);
171     }
172 
173     if (Nc4hw4PassMatch(kernels, index)) {
174       Nc4hw4PassReplace(kernels, tensors, index);
175       index += 1;
176     }
177     kernel_size = kernels->size();
178   }
179   return;
180 }
181 
ConvNormC4PassActReplace(kernel::LiteKernel * conv_op,kernel::LiteKernel * in_op)182 void ConvNormC4PassActReplace(kernel::LiteKernel *conv_op, kernel::LiteKernel *in_op) {
183   conv_op->out_tensors().front()->set_format(NC4HW4);
184   in_op->in_tensors().front()->set_format(NC4HW4);
185 }
186 
ConvNormC4PassActIndex(std::vector<kernel::LiteKernel * > * kernels,size_t index)187 void ConvNormC4PassActIndex(std::vector<kernel::LiteKernel *> *kernels, size_t index) {
188   kernel::LiteKernel *start_kernel = kernels->at(index);
189   if (start_kernel->type() != ConvNormC4OpConv2DFusion) {
190     return;
191   }
192   if (start_kernel->out_kernels().size() != 1) {
193     return;
194   }
195   if (reinterpret_cast<ConvParameter *>(start_kernel->op_parameter())->group_ != 1) {
196     /* conv-depthwise and group-conv */
197     return;
198   }
199 
200   kernel::LiteKernel *after_kernel = start_kernel->out_kernels().front();
201   if (after_kernel->type() == ConvNormC4OpActivation) {
202     if (after_kernel->out_kernels().size() != 1) {
203       return;
204     }
205     kernel::LiteKernel *end_kernel = after_kernel->out_kernels().front();
206     if (end_kernel->type() == ConvNormC4OpInstanceNorm) {
207       ConvNormC4PassActReplace(start_kernel, end_kernel);
208       return;
209     }
210     return;
211   }
212 
213   if (after_kernel->type() == ConvNormC4OpInstanceNorm) {
214     ConvNormC4PassActReplace(start_kernel, after_kernel);
215     return;
216   }
217 
218   return;
219 }
220 
ConvNormC4PassAct(std::vector<kernel::LiteKernel * > * kernels)221 void ConvNormC4PassAct(std::vector<kernel::LiteKernel *> *kernels) {
222   size_t kernel_size = kernels->size();
223   size_t index = 0;
224   for (; index < kernel_size; index++) {
225     ConvNormC4PassActIndex(kernels, index);
226   }
227   return;
228 }
229 
RuntimePass(std::vector<kernel::LiteKernel * > * subgraphs,std::vector<Tensor * > * tensors)230 void RuntimePass(std::vector<kernel::LiteKernel *> *subgraphs, std::vector<Tensor *> *tensors) {
231   for (auto subgraph : *subgraphs) {
232     auto sub = reinterpret_cast<kernel::SubGraphKernel *>(subgraph);
233     if (RuntimePassValid(sub) == false) {
234       continue;
235     }
236 
237     int i = 0;
238     auto &kernels = sub->nodes();
239     Nc4hw4PassAct(&kernels, tensors, i);
240     ConvNormC4PassAct(&kernels);
241   }
242 }
243 }  // namespace mindspore::lite
244