/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <set>
#include <queue>
#include "src/litert/pass/format_pass/eliminate_transpose.h"
#include "src/litert/kernel_exec_util.h"

namespace mindspore::lite::pass {
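// Fuses a pair of directly connected transpose kernels whose permutations cancel each other out: the surrounding
// kernels are reconnected directly, and both transposes are dropped (the first one only if nothing else consumes
// its output).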
int TransFullyFusion(kernel::SubGraphKernel *subgraph, kernel::KernelExec *trans_kernel0,
                     kernel::KernelExec *trans_kernel1) {
  CHECK_NULL_RETURN(trans_kernel0);
  CHECK_NULL_RETURN(trans_kernel1);
  auto in_tensor = trans_kernel0->in_tensors().at(0);
  auto out_tensor = trans_kernel1->out_tensors().at(0);
  auto in_kernel = kernel::KernelExecUtil::FindInKernelForInTensor(trans_kernel0, in_tensor);
  auto out_kernels = kernel::KernelExecUtil::FindOutKernelsForOutTensor(trans_kernel1, out_tensor);
  subgraph->UpdateInOutKernels(in_kernel, out_kernels, trans_kernel0, trans_kernel1);
  auto ret = subgraph->UpdateInOutTensors(in_kernel, out_kernels, in_tensor, out_tensor, true);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Update tensor failed when fusing kernel " << trans_kernel0->name() << " and "
                  << trans_kernel1->name();
    return RET_ERROR;
  }
  subgraph->DropNode(trans_kernel1);
  delete trans_kernel1;
  if (trans_kernel0->out_kernels().empty() && !IsContain(subgraph->out_tensors(), trans_kernel0->out_tensors().at(0))) {
    subgraph->DropNode(trans_kernel0);
    delete trans_kernel0;
  }
  return RET_OK;
}
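// Fuses two chained transpose kernels whose intermediate format matches (e.g. nhwc2nchw + nchw2nc4hw4) into a
// single transpose created by create_format_transpose_func, then drops the original pair.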
int TransHeadTailFusion(kernel::SubGraphKernel *subgraph, kernel::KernelExec *trans_kernel0,
                        kernel::KernelExec *trans_kernel1, const TransInfoPair &trans_info,
                        const CreateFormatTransposeFunc &create_format_transpose_func) {
  CHECK_NULL_RETURN(trans_kernel0);
  CHECK_NULL_RETURN(trans_kernel1);
  CHECK_NULL_RETURN(create_format_transpose_func);
  auto ctx = trans_kernel0->Context();
  auto desc = trans_kernel0->desc();
  auto in_tensor = trans_kernel0->in_tensors().at(0);
  auto out_tensor = trans_kernel1->out_tensors().at(0);
  auto in_kernel = kernel::KernelExecUtil::FindInKernelForInTensor(trans_kernel0, in_tensor);
  auto out_kernels = kernel::KernelExecUtil::FindOutKernelsForOutTensor(trans_kernel1, out_tensor);
  subgraph->UpdateInOutKernels(in_kernel, out_kernels, trans_kernel0, trans_kernel1);
  // new trans kernel: src_format -> dst_format
  auto trans_name = trans_kernel0->name() + "_and_" + trans_kernel1->name() + "_fusion";
  auto kernel = create_format_transpose_func(in_tensor, out_tensor, trans_info, trans_name, ctx, desc);
  CHECK_NULL_RETURN(kernel);
  if (in_kernel != nullptr) {
    in_kernel->AddOutKernel(kernel);
    kernel->AddInKernel(in_kernel);
  }
  for (const auto &out_kernel : out_kernels) {
    if (in_kernel != nullptr) {
      in_kernel->RemoveOutKernel(out_kernel);
      out_kernel->RemoveInKernel(in_kernel);
    }
    out_kernel->AddInKernel(kernel);
    kernel->AddOutKernel(out_kernel);
  }
  subgraph->nodes().push_back(kernel);

  subgraph->DropNode(trans_kernel1);
  delete trans_kernel1;
  if (trans_kernel0->out_kernels().empty() && !IsContain(subgraph->out_tensors(), trans_kernel0->out_tensors().at(0))) {
    subgraph->DropNode(trans_kernel0);
    delete trans_kernel0;
  }
  return RET_OK;
}
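// Re-packs a 4-dimensional const tensor into the destination format of pre_trans. Only the shape/format metadata
// and the data buffer are prepared here; the actual data transposition (TransData) is unavailable in this pass,
// so the function currently always ends in an error for 4D tensors.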
int PackConstData(Tensor *tensor, const TransInfoPair &pre_trans) {
  if (tensor->shape().size() != 4) {
    MS_LOG(ERROR) << "Pack const data is only valid for 4-dims tensors.";
    return RET_ERROR;
  }
  auto allocator = tensor->allocator();
  auto original_data = tensor->data();
  auto original_own_data = tensor->own_data();

  if (!TransTensorShapeAndFormat(tensor, pre_trans.dst_format_)) {
    MS_LOG(ERROR) << "Transpose tensor shape and format failed";
    return RET_ERROR;
  }
  tensor->set_data(nullptr);

  auto ret = tensor->MallocData();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Malloc new format data failed";
    return ret;
  }

  // Release the buffer in the original format; the tensor now owns the newly allocated one.
  if (original_own_data) {
    if (allocator != nullptr) {
      allocator->Free(original_data);
    } else {
      free(original_data);
    }
  }
  MS_LOG(ERROR) << "Can't call TransData function.";
  return RET_ERROR;
}
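// Applies the inverse of pre_trans to every input of the kernel: const inputs are re-packed in place, while
// variable inputs get an inserted transpose (dst_format -> src_format) that cancels the transpose being eliminated.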
int DoPreFusion(kernel::SubGraphKernel *subgraph, kernel::KernelExec *kernel, std::vector<Tensor *> *all_tensors,
                const TransInfoPair &pre_trans, const CreateFormatTransposeFunc &create_format_transpose_func) {
  CHECK_NULL_RETURN(create_format_transpose_func);
  for (size_t i = 0; i < kernel->in_tensors().size(); i++) {
    auto in_tensor = kernel->in_tensors().at(i);
    if (in_tensor->IsConst()) {
      auto ret = PackConstData(in_tensor, pre_trans);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Pack tensor " << in_tensor->tensor_name() << " data failed.";
        return RET_ERROR;
      }
      continue;
    }
    auto ret =
      InsertPreTranspose(subgraph, kernel, all_tensors, TransInfoPair(pre_trans.dst_format_, pre_trans.src_format_), i,
                         create_format_transpose_func);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Insert pre transpose for " << kernel->name() << " (input index: " << i
                    << ") failed while eliminating transposes across the kernel.";
      return RET_ERROR;
    }
  }
  return RET_OK;
}
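// Adjusts every output of the kernel after its axes have been changed: consumers that are transposes identical to
// post_trans are deleted, and all other consumers get an inserted inverse transpose (dst_format -> src_format) in
// front of the affected input.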
int DoPostFusion(kernel::SubGraphKernel *subgraph, const kernel::KernelExec *kernel, std::vector<Tensor *> *all_tensors,
                 const TransInfoPair &post_trans, const CreateFormatTransposeFunc &create_format_transpose_func) {
  CHECK_NULL_RETURN(create_format_transpose_func);
  for (size_t i = 0; i < kernel->out_tensors().size(); i++) {
    auto tensor = kernel->out_tensors().at(i);
    auto out_kernels = kernel::KernelExecUtil::FindOutKernelsForOutTensor(kernel, tensor);

    std::vector<kernel::KernelExec *> to_deletes;
    for (const auto &out_kernel : out_kernels) {
      TransInfoPair out_kernel_trans;
      auto ret = GetTransposeInfo(out_kernel, &out_kernel_trans);
      if (ret == RET_OK && IsSameTranspose(post_trans, out_kernel_trans)) {
        (void)to_deletes.emplace_back(out_kernel);
        continue;
      }
      auto in_tensor_of_out_kernel_idxes = out_kernel->FindAllInTensorIndex(tensor);
      for (auto &in_tensor_of_out_kernel_idx : in_tensor_of_out_kernel_idxes) {
        ret = InsertPreTranspose(subgraph, out_kernel, all_tensors,
                                 TransInfoPair(post_trans.dst_format_, post_trans.src_format_),
                                 in_tensor_of_out_kernel_idx, create_format_transpose_func);
        if (ret != RET_OK) {
          MS_LOG(ERROR) << "Insert pre transpose kernel for op: " << out_kernel->name() << " input tensor "
                        << in_tensor_of_out_kernel_idx << " failed.";
          return RET_ERROR;
        }
      }
    }
    for (auto &to_delete : to_deletes) {
      auto ret = subgraph->DeleteSingleWayNode(to_delete, false);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Delete kernel: " << to_delete->name() << " failed.";
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}
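// For each kernel sandwiched between transposes (as judged by CrossKernelFusionPreCheck), try to change the
// kernel's axes directly, then fuse away the surrounding transposes on its inputs and outputs.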
int EliminateTranspose::EliminateForSingleKernel(kernel::SubGraphKernel *subgraph, std::vector<Tensor *> *all_tensors) {
  auto kernels = &(subgraph->nodes());
  auto kernel_iter = kernels->begin();
  while (kernel_iter != kernels->end()) {
    auto kernel = *kernel_iter;
    CHECK_NULL_RETURN(kernel);
    TransInfoPair pre_trans;
    TransInfoPair post_trans;
    if (!transpose_strategy_.CrossKernelFusionPreCheck(kernel, &pre_trans, &post_trans)) {
      (void)kernel_iter++;
      continue;
    }
    auto ret = TransposeStrategy::TryTransKernelAxis(kernel, post_trans);
    if (ret == RET_NO_CHANGE) {
      // Some kernels can not be fused even though CrossKernelFusionPreCheck succeeds. For example, whether Crop
      // can transpose its axis depends on the axis attribute of the Crop primitive.
      (void)kernel_iter++;
      continue;
    }
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Change kernel axis of " << kernel->name() << " failed.";
      return RET_ERROR;
    }

    graph_changed_ = true;
    ret = DoPreFusion(subgraph, kernel, all_tensors, pre_trans, this->create_format_transpose_func_);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Fusion for pre transpose of " << kernel->name() << " failed.";
      return RET_ERROR;
    }
    ret = DoPostFusion(subgraph, kernel, all_tensors, post_trans, this->create_format_transpose_func_);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Fusion for post transpose of " << kernel->name() << " failed.";
      return RET_ERROR;
    }
    // DoPreFusion/DoPostFusion insert and delete nodes, which may invalidate the iterator; relocate the kernel.
    kernel_iter = find(kernels->begin(), kernels->end(), kernel);
    (void)kernel_iter++;
    MS_LOG(INFO) << "Fuse transpose across: " << kernel->name();
  }
  return RET_OK;
}
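// Traverses the subgraph tensor by tensor (BFS from the subgraph inputs) and merges "horizontal" duplicates:
// several consumers of one tensor performing the same transpose are collapsed into a single reserved transpose
// kernel, and the consumers of the deleted duplicates are rewired to its output.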
int EliminateTranspose::HorizontalTransposeFusionPass(kernel::SubGraphKernel *subgraph) {
  auto in_tensors = subgraph->in_tensors();
  std::queue<lite::Tensor *> tensor_queue;
  for (const auto &tensor : in_tensors) {
    tensor_queue.push(tensor);
  }
  std::set<lite::Tensor *> visited;
  while (!tensor_queue.empty()) {
    auto tensor = tensor_queue.front();
    tensor_queue.pop();
    visited.insert(tensor);
    auto in_kernel = kernel::KernelExecUtil::FindInKernelForTensorInSubGraph(tensor, subgraph);
    auto out_kernels = kernel::KernelExecUtil::FindOutKernelsForTensorInSubGraph(tensor, subgraph);
    for (const auto &out_kernel : out_kernels) {
      for (const auto &out_tensor : out_kernel->out_tensors()) {
        if (visited.find(out_tensor) == visited.end()) {
          tensor_queue.push(out_tensor);
        }
      }
    }

    TransInfoPair post_trans;
    auto count = transpose_strategy_.GetTransCount(out_kernels, &post_trans);
    if (count <= 1) {
      continue;
    }

    graph_changed_ = true;

    kernel::KernelExec *reserve_kernel = nullptr;
    std::vector<kernel::KernelExec *> to_deletes;
    for (const auto &out_kernel : out_kernels) {
      TransInfoPair tmp_trans;
      if (GetTransposeInfo(out_kernel, &tmp_trans) != RET_OK || !IsSameTranspose(post_trans, tmp_trans)) {
        continue;
      }
      if (reserve_kernel == nullptr) {
        // The first matching transpose becomes the reserved kernel.
        reserve_kernel = out_kernel;
        continue;
      }
      if (IsContain(subgraph->out_tensors(), out_kernel->out_tensors().at(0))) {
        // Prefer to reserve a transpose whose output is a subgraph output.
        to_deletes.push_back(reserve_kernel);
        reserve_kernel = out_kernel;
      } else {
        to_deletes.push_back(out_kernel);
      }
    }
    auto reserve_tensor = reserve_kernel->out_tensors().at(0);

    for (const auto &to_delete : to_deletes) {
      if (to_delete == reserve_kernel) {
        continue;
      }

      if (in_kernel != nullptr) {
        in_kernel->RemoveOutKernel(to_delete);
        to_delete->RemoveInKernel(in_kernel);
      }

      auto post_kernels = kernel::KernelExecUtil::FindOutKernelsForOutTensor(to_delete, to_delete->out_tensors().at(0));
      for (const auto &post : post_kernels) {
        to_delete->RemoveOutKernel(post);
        post->RemoveInKernel(to_delete);

        post->AddInKernel(reserve_kernel);
        reserve_kernel->AddOutKernel(post);

        auto input_indexes = post->FindAllInTensorIndex(to_delete->out_tensors().at(0));
        for (auto &input_index : input_indexes) {
          post->set_in_tensor(reserve_tensor, input_index);
        }
      }
      subgraph->DropNode(to_delete);
      delete to_delete;
    }
    if (in_kernel != nullptr) {
      MS_LOG(INFO) << "Fuse horizontal-transposes after: " << in_kernel->name();
    } else {
      MS_LOG(INFO) << "Fuse horizontal-transposes on tensor: " << tensor->tensor_name();
    }
  }
  return RET_OK;
}
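// Finds directly connected transpose pairs. If the second transpose undoes the first, both are removed
// (TransFullyFusion); if the formats merely chain (format1 -> format2 -> format3), the pair is replaced by a
// single transpose (TransHeadTailFusion).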
int EliminateTranspose::DoubleTransposeFusion(kernel::SubGraphKernel *subgraph) {
  auto kernels = &(subgraph->nodes());
  auto kernel_iter = kernels->begin();
  while (kernel_iter != kernels->end()) {
    auto &kernel = *kernel_iter;
    CHECK_NULL_RETURN(kernel);
    (void)kernel_iter++;

    if (kernel->in_kernels().size() != 1) {
      continue;
    }

    auto pre_kernel = kernel->in_kernels().at(0);
    if (!IsContain(subgraph->nodes(), pre_kernel)) {
      continue;
    }

    TransInfoPair post_trans_info;
    if (GetTransposeInfo(kernel, &post_trans_info) != RET_OK) {
      MS_LOG(DEBUG) << "The kernel " << kernel->name() << " isn't a transpose and can't be eliminated.";
      continue;
    }

    TransInfoPair pre_trans_info;
    if (GetTransposeInfo(pre_kernel, &pre_trans_info) != RET_OK) {
      MS_LOG(DEBUG) << "The kernel " << pre_kernel->name() << " isn't a transpose and can't be eliminated.";
      continue;
    }

    if (pre_trans_info.dst_format_ != post_trans_info.src_format_) {
      MS_LOG(DEBUG) << "Two transposes " << pre_kernel->name() << " and " << kernel->name()
                    << " are connected back to back but their perms don't match, so they can not be eliminated. "
                    << "Maybe they could be fused into one transpose.";
      continue;
    }

    graph_changed_ = true;
    // record the next kernel to update the iterator afterwards
    auto next_kernel = (kernel_iter == kernels->end()) ? nullptr : (*kernel_iter);

    int ret = RET_OK;
    if (pre_trans_info.src_format_ == post_trans_info.dst_format_) {
      // opposite pattern, like: nhwc2nchw & nchw2nhwc -> none
      ret = TransFullyFusion(subgraph, pre_kernel, kernel);
    } else {
      // chained pattern, the previous dst format and the post src format are the same:
      // op1: format1 -> format2, op2: format2 -> format3, like: nhwc2nchw & nchw2nc4hw4 -> nhwc2nc4hw4
      TransInfoPair new_trans_info(pre_trans_info.src_format_, post_trans_info.dst_format_);
      ret = TransHeadTailFusion(subgraph, pre_kernel, kernel, new_trans_info, this->create_format_transpose_func_);
    }
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Fusion of " << pre_kernel->name() << " and " << kernel->name() << " failed";
      return RET_ERROR;
    }

    // The dropped kernel may be in front of the current kernel, update the kernel iterator.
    kernel_iter = (next_kernel == nullptr) ? kernels->end() : (find(kernels->begin(), kernels->end(), next_kernel));
  }
  return RET_OK;
}
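// Entry point of the pass: repeats the three fusion steps until the graph stops changing (or max_pass_count_ is
// reached), then restores the topological order of the remaining kernels.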
int EliminateTranspose::RunPass(kernel::SubGraphKernel *graph, std::vector<lite::Tensor *> *tensors) {
  int pass_count = 0;
  while (graph_changed_ && pass_count < max_pass_count_) {
    graph_changed_ = false;

    auto ret = DoubleTransposeFusion(graph);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Double transpose fusion failed in runtime pass.";
      return RET_ERROR;
    }

    ret = EliminateForSingleKernel(graph, tensors);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Eliminate for single kernel failed in runtime pass.";
      return RET_ERROR;
    }

    ret = HorizontalTransposeFusionPass(graph);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "HorizontalTransposeFusionPass failed in runtime pass.";
      return RET_ERROR;
    }

    pass_count++;
  }

  auto ret = graph->TopologicalSortNodes();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Topological sort kernels failed.";
    return RET_ERROR;
  }
  return RET_OK;
}
}  // namespace mindspore::lite::pass