/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/runtime_pass.h"
#include "nnacl/conv_parameter.h"

namespace {
constexpr int kMaxDepth = 2048;
}  // namespace

namespace mindspore::lite {
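/* Fold the matched Conv -> Transpose -> C4-op -> Transpose chain: drop both Transpose kernels,
 * feed the convolution output (re-tagged as NC4HW4) directly into the C4 kernel, rewire the C4
 * kernel's output (restored to NHWC format and shape) to the kernels that followed the second
 * Transpose, and free the tensors that are no longer used. */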
void Nc4hw4PassReplace(std::vector<kernel::LiteKernel *> *kernels, std::vector<Tensor *> *tensors, size_t index) {
  kernel::LiteKernel *conv_kernel = kernels->at(index);
  kernel::LiteKernel *transpose_kernel = conv_kernel->out_kernels().front();
  kernel::LiteKernel *c4_kernel = transpose_kernel->out_kernels().front();
  kernel::LiteKernel *transpose2_kernel = c4_kernel->out_kernels().front();
  std::vector<kernel::LiteKernel *> end_kernels = transpose2_kernel->out_kernels();

  /* tensor */
  {
    /* transpose_kernel */
    Tensor *transpose_param_tensor = transpose_kernel->in_tensors().at(1);
    VectorSetNull(tensors, transpose_param_tensor);
    delete transpose_param_tensor;
    transpose_param_tensor = nullptr;

    Tensor *conv_out_tensor = conv_kernel->out_tensors().front();
    conv_out_tensor->set_format(NC4HW4);
    Tensor *c4_input_tensor = c4_kernel->in_tensors().front();
    c4_kernel->set_in_tensor(conv_out_tensor, 0);
    VectorSetNull(tensors, c4_input_tensor);
    delete c4_input_tensor;
    c4_input_tensor = nullptr;
  }
  {
    /* transpose2_kernel */
    Tensor *transpose_param_tensor = transpose2_kernel->in_tensors().at(1);
    VectorSetNull(tensors, transpose_param_tensor);
    delete transpose_param_tensor;
    transpose_param_tensor = nullptr;

    /* the C4 kernel's output becomes the NHWC tensor consumed by the end kernels */
    Tensor *nhwc_tensor = c4_kernel->out_tensors().front();
    std::vector<int> nhwc_shape = {nhwc_tensor->Batch(), nhwc_tensor->Height(), nhwc_tensor->Width(),
                                   nhwc_tensor->Channel()};
    nhwc_tensor->set_format(NHWC);
    nhwc_tensor->set_shape(nhwc_shape);
    for (auto end : end_kernels) {
      end->set_in_tensor(nhwc_tensor, 0);
    }
    Tensor *trans_out = transpose2_kernel->out_tensors().front();
    VectorSetNull(tensors, trans_out);
    delete trans_out;
    trans_out = nullptr;
  }

  /* kernel */
  VectorErase(kernels, transpose_kernel);
  delete transpose_kernel;
  transpose_kernel = nullptr;
  conv_kernel->set_out_kernels({c4_kernel});
  c4_kernel->set_in_kernels({conv_kernel});

  c4_kernel->set_out_kernels(transpose2_kernel->out_kernels());
  for (auto end : end_kernels) {
    end->set_in_kernels({c4_kernel});
  }
  VectorErase(kernels, transpose2_kernel);
  delete transpose2_kernel;
  transpose2_kernel = nullptr;

  return;
}

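/* Match the four-kernel pattern: an op from Nc4hw4FormatOutOpList -> Transpose (NHWC to NCHW) ->
 * an op from Nc4hw4FormatInOpList -> Transpose (NCHW to NHWC), where each intermediate kernel has
 * exactly one consumer and the starting convolution is neither depthwise nor grouped. Also verify
 * that the four kernels appear in topological order in the kernel list. */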
bool Nc4hw4PassMatch(std::vector<kernel::LiteKernel *> *kernels, size_t index) {
  kernel::LiteKernel *start_kernel = kernels->at(index);
  if (IsContain(Nc4hw4FormatOutOpList, start_kernel->type()) == false) {
    return false;
  }
  if (start_kernel->out_kernels().size() != 1) {
    return false;
  }
  if (reinterpret_cast<ConvParameter *>(start_kernel->op_parameter())->group_ != 1) {
    /* conv-depthwise and group-conv */
    return false;
  }

  kernel::LiteKernel *transpose_nhwc2nchw_kernel = start_kernel->out_kernels().front();
  if (transpose_nhwc2nchw_kernel->type() != Nc4hw4FormatTransposeOp) {
    return false;
  }
  if (transpose_nhwc2nchw_kernel->out_kernels().size() != 1) {
    return false;
  }

  kernel::LiteKernel *end_kernel = transpose_nhwc2nchw_kernel->out_kernels().front();
  if (IsContain(Nc4hw4FormatInOpList, end_kernel->type()) == false) {
    return false;
  }
  if (end_kernel->out_kernels().size() != 1) {
    return false;
  }

  kernel::LiteKernel *transpose_nchw2nhwc_kernel = end_kernel->out_kernels().front();
  if (transpose_nchw2nhwc_kernel->type() != Nc4hw4FormatTransposeOp) {
    return false;
  }

  /* double-check that the matched ops are topologically sorted in the kernel list */
  auto start_iter = find(kernels->begin(), kernels->end(), start_kernel);
  auto start_index = std::distance(kernels->begin(), start_iter);
  auto transpose_nhwc2nchw_iter = find(kernels->begin(), kernels->end(), transpose_nhwc2nchw_kernel);
  auto transpose_nhwc2nchw_index = std::distance(kernels->begin(), transpose_nhwc2nchw_iter);
  auto end_iter = find(kernels->begin(), kernels->end(), end_kernel);
  auto end_index = std::distance(kernels->begin(), end_iter);
  auto transpose_nchw2nhwc_iter = find(kernels->begin(), kernels->end(), transpose_nchw2nhwc_kernel);
  auto transpose_nchw2nhwc_index = std::distance(kernels->begin(), transpose_nchw2nhwc_iter);
  if (start_index > transpose_nhwc2nchw_index || transpose_nhwc2nchw_index > end_index ||
      end_index > transpose_nchw2nhwc_index) {
    return false;
  }

  return true;
}

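/* The runtime passes only apply to CPU subgraphs that contain no quantized kernels
 * (neither aware-training nor post-training quantization). */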
bool RuntimePassValid(kernel::SubGraphKernel *subgraph) {
  if (subgraph->desc().arch != kernel::KERNEL_ARCH::kCPU) {
    return false;
  }

  auto kernels = subgraph->nodes();

  for (auto kernel : kernels) {
    if (kernel->op_parameter() != nullptr) {
      if (kernel->op_parameter()->quant_type_ == schema::QuantType_AwareTraining ||
          kernel->op_parameter()->quant_type_ == schema::QuantType_PostTraining) {
        return false;
      }
    }
  }
  return true;
}

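/* Walk the kernel list, recursing into nested subgraphs (recursion depth bounded by kMaxDepth),
 * and fold every occurrence of the pattern recognized by Nc4hw4PassMatch. */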
void Nc4hw4PassAct(std::vector<kernel::LiteKernel *> *kernels, std::vector<Tensor *> *tensors, int i) {
  if (i > kMaxDepth) {
    MS_LOG(ERROR) << "exceed max depth " << kMaxDepth << ", i " << i;
    return;
  }
  i++;
  size_t kernel_size = kernels->size();
  size_t index = 0;
  for (; index + 3 < kernel_size; index++) {
    kernel::LiteKernel *kernel = kernels->at(index);

    if (kernel->subgraph_type() != kernel::kNotSubGraph) {
      kernel::SubGraphKernel *subgraph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
      std::vector<kernel::LiteKernel *> &partial_nodes = subgraph->nodes();
      Nc4hw4PassAct(&partial_nodes, tensors, i);
    }

    if (Nc4hw4PassMatch(kernels, index)) {
      Nc4hw4PassReplace(kernels, tensors, index);
      index += 1;
    }
    kernel_size = kernels->size();
  }
  return;
}

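/* Tag the convolution's output tensor and the consuming op's input tensor as NC4HW4 so the
 * intermediate data stays in the C4 layout between the two kernels. */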
void ConvNormC4PassActReplace(kernel::LiteKernel *conv_op, kernel::LiteKernel *in_op) {
  conv_op->out_tensors().front()->set_format(NC4HW4);
  in_op->in_tensors().front()->set_format(NC4HW4);
}

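/* Check whether the kernel at |index| starts a Conv2DFusion -> (optional Activation) -> InstanceNorm
 * chain with single consumers and group == 1; if so, switch the connecting tensors to NC4HW4. */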
void ConvNormC4PassActIndex(std::vector<kernel::LiteKernel *> *kernels, size_t index) {
  kernel::LiteKernel *start_kernel = kernels->at(index);
  if (start_kernel->type() != ConvNormC4OpConv2DFusion) {
    return;
  }
  if (start_kernel->out_kernels().size() != 1) {
    return;
  }
  if (reinterpret_cast<ConvParameter *>(start_kernel->op_parameter())->group_ != 1) {
    /* conv-depthwise and group-conv */
    return;
  }

  kernel::LiteKernel *after_kernel = start_kernel->out_kernels().front();
  if (after_kernel->type() == ConvNormC4OpActivation) {
    if (after_kernel->out_kernels().size() != 1) {
      return;
    }
    kernel::LiteKernel *end_kernel = after_kernel->out_kernels().front();
    if (end_kernel->type() == ConvNormC4OpInstanceNorm) {
      ConvNormC4PassActReplace(start_kernel, end_kernel);
      return;
    }
    return;
  }

  if (after_kernel->type() == ConvNormC4OpInstanceNorm) {
    ConvNormC4PassActReplace(start_kernel, after_kernel);
    return;
  }

  return;
}

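/* Apply the Conv + InstanceNorm C4 pass to every kernel in the list. */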
void ConvNormC4PassAct(std::vector<kernel::LiteKernel *> *kernels) {
  size_t kernel_size = kernels->size();
  size_t index = 0;
  for (; index < kernel_size; index++) {
    ConvNormC4PassActIndex(kernels, index);
  }
  return;
}

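/* Entry point of the runtime passes: run the NC4HW4 transpose-folding pass and the
 * Conv + InstanceNorm C4 pass on every eligible (CPU, non-quantized) subgraph. */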
void RuntimePass(std::vector<kernel::LiteKernel *> *subgraphs, std::vector<Tensor *> *tensors) {
  for (auto subgraph : *subgraphs) {
    auto sub = reinterpret_cast<kernel::SubGraphKernel *>(subgraph);
    if (RuntimePassValid(sub) == false) {
      continue;
    }

    int i = 0;
    auto &kernels = sub->nodes();
    Nc4hw4PassAct(&kernels, tensors, i);
    ConvNormC4PassAct(&kernels);
  }
}
}  // namespace mindspore::lite