/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <set>
#include <queue>
#include "src/litert/pass/format_pass/eliminate_transpose.h"
#include "src/litert/kernel_exec_util.h"

namespace mindspore::lite::pass {
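// Fuse two directly-connected transposes that cancel each other out (e.g. nhwc2nchw followed by
// nchw2nhwc): rewire the producer of the first transpose to the consumers of the second one and
// drop both kernels.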
int TransFullyFusion(kernel::SubGraphKernel *subgraph, kernel::KernelExec *trans_kernel0,
                     kernel::KernelExec *trans_kernel1) {
  CHECK_NULL_RETURN(trans_kernel0);
  CHECK_NULL_RETURN(trans_kernel1);
  auto in_tensor = trans_kernel0->in_tensors().at(0);
  auto out_tensor = trans_kernel1->out_tensors().at(0);
  auto in_kernel = kernel::KernelExecUtil::FindInKernelForInTensor(trans_kernel0, in_tensor);
  auto out_kernels = kernel::KernelExecUtil::FindOutKernelsForOutTensor(trans_kernel1, out_tensor);
  subgraph->UpdateInOutKernels(in_kernel, out_kernels, trans_kernel0, trans_kernel1);
  auto ret = subgraph->UpdateInOutTensors(in_kernel, out_kernels, in_tensor, out_tensor, true);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Update tensor failed when fusing kernel " << trans_kernel0->name() << " and "
                  << trans_kernel1->name();
    return RET_ERROR;
  }
  subgraph->DropNode(trans_kernel1);
  delete trans_kernel1;
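  // The first transpose may still feed other consumers; drop it only when it has no remaining
  // out-kernels and its output tensor is not a graph output.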
  if (trans_kernel0->out_kernels().empty() && !IsContain(subgraph->out_tensors(), trans_kernel0->out_tensors().at(0))) {
    subgraph->DropNode(trans_kernel0);
    delete trans_kernel0;
  }
  return RET_OK;
}
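// Fuse two directly-connected transposes whose formats chain (format1 -> format2 -> format3) into
// a single new transpose (format1 -> format3, given by trans_info) created through
// create_format_transpose_func, then drop the original pair.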
int TransHeadTailFusion(kernel::SubGraphKernel *subgraph, kernel::KernelExec *trans_kernel0,
                        kernel::KernelExec *trans_kernel1, const TransInfoPair &trans_info,
                        const CreateFormatTransposeFunc &create_format_transpose_func) {
  CHECK_NULL_RETURN(trans_kernel0);
  CHECK_NULL_RETURN(trans_kernel1);
  CHECK_NULL_RETURN(create_format_transpose_func);
  auto ctx = trans_kernel0->Context();
  auto desc = trans_kernel0->desc();
  auto in_tensor = trans_kernel0->in_tensors().at(0);
  auto out_tensor = trans_kernel1->out_tensors().at(0);
  auto in_kernel = kernel::KernelExecUtil::FindInKernelForInTensor(trans_kernel0, in_tensor);
  auto out_kernels = kernel::KernelExecUtil::FindOutKernelsForOutTensor(trans_kernel1, out_tensor);
  subgraph->UpdateInOutKernels(in_kernel, out_kernels, trans_kernel0, trans_kernel1);
  // new trans kernel: src_format -> dst_format
  auto trans_name = trans_kernel0->name() + "_and_" + trans_kernel1->name() + "_fusion";
  auto kernel = create_format_transpose_func(in_tensor, out_tensor, trans_info, trans_name, ctx, desc);
  CHECK_NULL_RETURN(kernel);
  if (in_kernel != nullptr) {
    in_kernel->AddOutKernel(kernel);
    kernel->AddInKernel(in_kernel);
  }
  for (const auto &out_kernel : out_kernels) {
    if (in_kernel != nullptr) {
      in_kernel->RemoveOutKernel(out_kernel);
      out_kernel->RemoveInKernel(in_kernel);
    }
    out_kernel->AddInKernel(kernel);
    kernel->AddOutKernel(out_kernel);
  }
  subgraph->nodes().push_back(kernel);

  subgraph->DropNode(trans_kernel1);
  delete trans_kernel1;
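  // As above, drop the first transpose only when nothing else consumes it and its output is not a
  // graph output.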
  if (trans_kernel0->out_kernels().empty() && !IsContain(subgraph->out_tensors(), trans_kernel0->out_tensors().at(0))) {
    subgraph->DropNode(trans_kernel0);
    delete trans_kernel0;
  }
  return RET_OK;
}

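// Repack a constant tensor into the layout implied by pre_trans. Note: only the shape and format
// are rewritten and a fresh buffer is allocated; the actual data transposition (TransData) is not
// available in this pass, so the function currently always reports failure at the end.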
int PackConstData(Tensor *tensor, const TransInfoPair &pre_trans) {
  if (tensor->shape().size() != 4) {
    MS_LOG(ERROR) << "Pack const data only valid for 4 dims tensor.";
    return RET_ERROR;  // Return an error consistent with the ERROR log: non-4D const data can not be repacked here.
  }
  auto allocator = tensor->allocator();
  auto original_data = tensor->data();
  auto original_own_data = tensor->own_data();

  if (!TransTensorShapeAndFormat(tensor, pre_trans.dst_format_)) {
    MS_LOG(ERROR) << "Transpose tensor shape and format failed";
    return RET_ERROR;
  }
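  // Detach the old buffer so that MallocData allocates a new buffer sized for the new layout.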
  tensor->set_data(nullptr);

  auto ret = tensor->MallocData();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Malloc new format data failed";
    return ret;
  }

  if (original_own_data) {
    if (allocator != nullptr) {
      allocator->Free(original_data);
    } else {
      free(original_data);
    }
  }
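  // Repacking the data itself is not implemented here: no TransData is available to fill the
  // newly allocated buffer, so report failure.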
  MS_LOG(ERROR) << "Can't call TransData function.";
  return RET_ERROR;
}

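// For each input of `kernel` (whose axes have already been rewritten by TryTransKernelAxis):
// repack constant inputs in place, and insert the inverse of pre_trans in front of every
// non-constant input.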
int DoPreFusion(kernel::SubGraphKernel *subgraph, kernel::KernelExec *kernel, std::vector<Tensor *> *all_tensors,
                const TransInfoPair &pre_trans, const CreateFormatTransposeFunc &create_format_transpose_func) {
  CHECK_NULL_RETURN(create_format_transpose_func);
  for (size_t i = 0; i < kernel->in_tensors().size(); i++) {
    auto in_tensor = kernel->in_tensors().at(i);
    if (in_tensor->IsConst()) {
      auto ret = PackConstData(in_tensor, pre_trans);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Pack tensor " << in_tensor->tensor_name() << " data failed.";
        return RET_ERROR;
      }
      continue;
    }
    auto ret =
      InsertPreTranspose(subgraph, kernel, all_tensors, TransInfoPair(pre_trans.dst_format_, pre_trans.src_format_), i,
                         create_format_transpose_func);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Insert pre transpose for " << kernel->name() << " (input index: " << i
                    << ") failed while eliminating transposes across the kernel";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

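// For each output of `kernel`: delete consumer transposes identical to post_trans (they are now
// redundant), and insert the inverse of post_trans in front of every other consumer so that its
// input format is unchanged.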
int DoPostFusion(kernel::SubGraphKernel *subgraph, const kernel::KernelExec *kernel, std::vector<Tensor *> *all_tensors,
                 const TransInfoPair &post_trans, const CreateFormatTransposeFunc &create_format_transpose_func) {
  CHECK_NULL_RETURN(create_format_transpose_func);
  for (size_t i = 0; i < kernel->out_tensors().size(); i++) {
    auto tensor = kernel->out_tensors().at(i);
    auto out_kernels = kernel::KernelExecUtil::FindOutKernelsForOutTensor(kernel, tensor);

    std::vector<kernel::KernelExec *> to_deletes;
    for (const auto &out_kernel : out_kernels) {
      TransInfoPair out_kernel_trans;
      auto ret = GetTransposeInfo(out_kernel, &out_kernel_trans);
      if (ret == RET_OK && IsSameTranspose(post_trans, out_kernel_trans)) {
        (void)to_deletes.emplace_back(out_kernel);
        continue;
      }
      auto in_tensor_of_out_kernel_idxes = out_kernel->FindAllInTensorIndex(tensor);
      for (auto &in_tensor_of_out_kernel_idx : in_tensor_of_out_kernel_idxes) {
        ret = InsertPreTranspose(subgraph, out_kernel, all_tensors,
                                 TransInfoPair(post_trans.dst_format_, post_trans.src_format_),
                                 in_tensor_of_out_kernel_idx, create_format_transpose_func);
        if (ret != RET_OK) {
          MS_LOG(ERROR) << "Insert pre transpose kernel for op: " << out_kernel->name() << " input tensor "
                        << in_tensor_of_out_kernel_idx << " failed.";
          return RET_ERROR;
        }
      }
    }
    for (auto &to_delete : to_deletes) {
      auto ret = subgraph->DeleteSingleWayNode(to_delete, false);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Delete kernel: " << to_delete->name() << " failed.";
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}

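// For every kernel sandwiched between a pre-transpose and a post-transpose that its axes can
// absorb, rewrite the kernel's axes and then fuse away the surrounding transposes.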
int EliminateTranspose::EliminateForSingleKernel(kernel::SubGraphKernel *subgraph, std::vector<Tensor *> *all_tensors) {
  auto kernels = &(subgraph->nodes());
  auto kernel_iter = kernels->begin();
  while (kernel_iter != kernels->end()) {
    auto kernel = *kernel_iter;
    CHECK_NULL_RETURN(kernel);
    TransInfoPair pre_trans;
    TransInfoPair post_trans;
    if (!transpose_strategy_.CrossKernelFusionPreCheck(kernel, &pre_trans, &post_trans)) {
      (void)kernel_iter++;
      continue;
    }
    auto ret = TransposeStrategy::TryTransKernelAxis(kernel, post_trans);
    if (ret == RET_NO_CHANGE) {
      // Some kernels can not be fused even though CrossKernelFusionPreCheck succeeds. For example,
      // whether Crop can transpose its axis depends on the axis attribute of the Crop primitive.
      (void)kernel_iter++;
      continue;
    }
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Change kernel axis " << kernel->name() << " failed.";
      return RET_ERROR;
    }

    graph_changed_ = true;
    ret = DoPreFusion(subgraph, kernel, all_tensors, pre_trans, this->create_format_transpose_func_);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Fusion for pre transpose of " << kernel->name() << " failed.";
      return RET_ERROR;
    }
    ret = DoPostFusion(subgraph, kernel, all_tensors, post_trans, this->create_format_transpose_func_);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Fusion for post transpose of " << kernel->name() << " failed.";
      return RET_ERROR;
    }
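    // DoPreFusion/DoPostFusion insert and erase nodes, which may invalidate kernel_iter;
    // re-locate the current kernel before advancing.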
    kernel_iter = find(kernels->begin(), kernels->end(), kernel);
    (void)kernel_iter++;
    MS_LOG(INFO) << "Fuse transpose across: " << kernel->name();
  }
  return RET_OK;
}

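// Fuse "horizontal" transposes: when several identical transposes consume the same tensor, keep a
// single one and redirect the consumers of the others to it. Tensors are visited in BFS order
// starting from the subgraph inputs.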
int EliminateTranspose::HorizontalTransposeFusionPass(kernel::SubGraphKernel *subgraph) {
  auto in_tensors = subgraph->in_tensors();
  std::queue<lite::Tensor *> tensor_queue;
  for (const auto &tensor : in_tensors) {
    tensor_queue.push(tensor);
  }
  std::set<lite::Tensor *> visited;
  while (!tensor_queue.empty()) {
    auto tensor = tensor_queue.front();
    tensor_queue.pop();
    visited.insert(tensor);
    auto in_kernel = kernel::KernelExecUtil::FindInKernelForTensorInSubGraph(tensor, subgraph);
    auto out_kernels = kernel::KernelExecUtil::FindOutKernelsForTensorInSubGraph(tensor, subgraph);
    for (const auto &out_kernel : out_kernels) {
      for (const auto &out_tensor : out_kernel->out_tensors()) {
        if (visited.find(out_tensor) == visited.end()) {
          tensor_queue.push(out_tensor);
        }
      }
    }

    TransInfoPair post_trans;
    auto count = transpose_strategy_.GetTransCount(out_kernels, &post_trans);
    if (count <= 1) {
      continue;
    }

    graph_changed_ = true;

    kernel::KernelExec *reserve_kernel = nullptr;
    std::vector<kernel::KernelExec *> to_deletes;
    for (const auto &out_kernel : out_kernels) {
      TransInfoPair tmp_trans;
      if (GetTransposeInfo(out_kernel, &tmp_trans) != RET_OK || !IsSameTranspose(post_trans, tmp_trans)) {
        continue;
      }
      if (reserve_kernel == nullptr) {
        // First matching transpose found: reserve it for now.
        reserve_kernel = out_kernel;
        continue;
      }
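      // Prefer to reserve a transpose whose output is a graph output, since that kernel can not
      // be deleted; the previously reserved kernel is deleted instead.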
      if (IsContain(subgraph->out_tensors(), out_kernel->out_tensors().at(0))) {
        to_deletes.push_back(reserve_kernel);
        reserve_kernel = out_kernel;
      } else {
        to_deletes.push_back(out_kernel);
      }
    }
    auto reserve_tensor = reserve_kernel->out_tensors().at(0);

    for (const auto &to_delete : to_deletes) {
      if (to_delete == reserve_kernel) {
        continue;
      }

      if (in_kernel != nullptr) {
        in_kernel->RemoveOutKernel(to_delete);
        to_delete->RemoveInKernel(in_kernel);
      }

      auto post_kernels = kernel::KernelExecUtil::FindOutKernelsForOutTensor(to_delete, to_delete->out_tensors().at(0));
      for (const auto &post : post_kernels) {
        to_delete->RemoveOutKernel(post);
        post->RemoveInKernel(to_delete);

        post->AddInKernel(reserve_kernel);
        reserve_kernel->AddOutKernel(post);

        auto input_indexes = post->FindAllInTensorIndex(to_delete->out_tensors().at(0));
        for (auto &input_index : input_indexes) {
          post->set_in_tensor(reserve_tensor, input_index);
        }
      }
      subgraph->DropNode(to_delete);
      delete to_delete;
    }
    if (in_kernel != nullptr) {
      MS_LOG(INFO) << "Fuse horizontal-transposes after: " << in_kernel->name();
    } else {
      MS_LOG(INFO) << "Fuse horizontal-transposes on tensor: " << tensor->tensor_name();
    }
  }
  return RET_OK;
}

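// Fuse pairs of directly-connected transposes: if the pair cancels out, remove it entirely
// (TransFullyFusion); if the formats chain, replace the pair with a single transpose
// (TransHeadTailFusion).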
int EliminateTranspose::DoubleTransposeFusion(kernel::SubGraphKernel *subgraph) {
  auto kernels = &(subgraph->nodes());
  auto kernel_iter = kernels->begin();
  while (kernel_iter != kernels->end()) {
    auto &kernel = *kernel_iter;
    CHECK_NULL_RETURN(kernel);
    (void)kernel_iter++;

    if (kernel->in_kernels().size() != 1) {
      continue;
    }

    auto pre_kernel = kernel->in_kernels().at(0);
    if (!IsContain(subgraph->nodes(), pre_kernel)) {
      continue;
    }

    TransInfoPair post_trans_info;
    if (GetTransposeInfo(kernel, &post_trans_info) != RET_OK) {
      MS_LOG(DEBUG) << "The kernel " << kernel->name() << " is not a transpose and can not be eliminated.";
      continue;
    }

    TransInfoPair pre_trans_info;
    if (GetTransposeInfo(pre_kernel, &pre_trans_info) != RET_OK) {
      MS_LOG(DEBUG) << "The kernel " << pre_kernel->name() << " is not a transpose and can not be eliminated.";
      continue;
    }

    if (pre_trans_info.dst_format_ != post_trans_info.src_format_) {
      MS_LOG(DEBUG) << "Transposes " << pre_kernel->name() << " and " << kernel->name()
                    << " are connected back to back, but their perms do not chain, so they can not be"
                    << " eliminated here. Maybe they could still be fused into one transpose.";
      continue;
    }

    graph_changed_ = true;
    // Record the next kernel so the iterator can be restored afterwards.
    auto next_kernel = (kernel_iter == kernels->end()) ? nullptr : (*kernel_iter);

    int ret = RET_OK;
    if (pre_trans_info.src_format_ == post_trans_info.dst_format_) {
      // Opposite pattern, like: nhwc2nchw & nchw2nhwc -> none.
      ret = TransFullyFusion(subgraph, pre_kernel, kernel);
    } else {
      // Chained pattern: the previous dst format and the next src format are the same.
      // op1: format1 -> format2, op2: format2 -> format3, like: nhwc2nchw & nchw2nc4hw4 -> nhwc2nc4hw4.
      TransInfoPair new_trans_info(pre_trans_info.src_format_, post_trans_info.dst_format_);
      ret = TransHeadTailFusion(subgraph, pre_kernel, kernel, new_trans_info, this->create_format_transpose_func_);
    }
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Fusing " << pre_kernel->name() << " and " << kernel->name() << " failed";
      return RET_ERROR;
    }

    // The dropped kernel may be in front of the current kernel; re-locate the iterator from the recorded next kernel.
    kernel_iter = (next_kernel == nullptr) ? kernels->end() : (find(kernels->begin(), kernels->end(), next_kernel));
  }
  return RET_OK;
}

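// Entry point of the pass: repeat the three fusion passes until the graph stops changing (or the
// iteration limit is reached), then restore a valid topological order of the remaining kernels.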
int EliminateTranspose::RunPass(kernel::SubGraphKernel *graph, std::vector<lite::Tensor *> *tensors) {
  int pass_count = 0;
  while (graph_changed_ && pass_count < max_pass_count_) {
    graph_changed_ = false;

    auto ret = DoubleTransposeFusion(graph);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Double transpose fusion failed in runtime pass.";
      return RET_ERROR;
    }

    ret = EliminateForSingleKernel(graph, tensors);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Eliminate for single kernel failed in runtime pass.";
      return RET_ERROR;
    }

    ret = HorizontalTransposeFusionPass(graph);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "HorizontalTransposeFusionPass failed in runtime pass.";
      return RET_ERROR;
    }

    pass_count++;
  }

  auto ret = graph->TopologicalSortNodes();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Topological sort kernels failed.";
    return RET_ERROR;
  }
  return RET_OK;
}
}  // namespace mindspore::lite::pass