1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/litert/pass/format_pass/insert_transpose.h"
18 #include "src/litert/pass/format_pass/format_utils.h"
19 #include "src/litert/kernel_exec_util.h"
20 #include "nnacl/base/format_transpose.h"
21
22 namespace mindspore::lite::pass {
TransposeConstData(kernel::KernelExec * kernel,size_t index)23 int InsertTranspose::TransposeConstData(kernel::KernelExec *kernel, size_t index) {
24 lite::Tensor *tensor = kernel->in_tensors().at(index);
25 Format except_format = kernel->desc().format;
26 if (tensor->format() == except_format) {
27 return RET_OK;
28 }
29
30 if (tensor->allocator() != nullptr) {
31 MS_LOG(ERROR) << "Const data allocator invalid.";
32 return RET_ERROR;
33 }
34
35 void *buffer = malloc(tensor->Size());
36 if (buffer == nullptr) {
37 MS_LOG(ERROR) << "malloc transpose data failed";
38 return RET_ERROR;
39 }
40 auto ret = TransData(tensor->data(), buffer, (FormatC)(tensor->format()), (FormatC)except_format,
41 static_cast<TypeIdC>(tensor->data_type()), tensor->Batch(), tensor->Channel(),
42 tensor->Height() * tensor->Width());
43 if (ret != RET_OK) {
44 return ret;
45 }
46
47 tensor->FreeData();
48 tensor->set_data(buffer, true);
49 if (!TransTensorShapeAndFormat(tensor, except_format)) {
50 MS_LOG(ERROR) << "unsupported except format: " << except_format;
51 return RET_ERROR;
52 }
53 return RET_OK;
54 }
55
// Walk every node of `graph` and, for nodes whose kernel format differs from
// the graph format (format_), insert format-transpose kernels on the selected
// inputs and on all outputs so data crossing the node boundary is converted.
// Const inputs are converted in place (TransposeConstData) instead of getting
// a runtime transpose node. After each modification the subgraph's in/out
// node sets are recomputed and the node list is re-sorted topologically.
// New tensors created for the inserted transposes are appended to `tensors`.
int InsertTranspose::RunPass(kernel::SubGraphKernel *graph, std::vector<lite::Tensor *> *tensors) {
  auto kernels = graph->nodes();

  // Snapshot the original node count: inserted transpose kernels grow the
  // node list, but only the original nodes must be visited.
  auto origin_kernel_size = kernels.size();
  for (size_t kernel_index = 0; kernel_index < origin_kernel_size; kernel_index++) {
    kernel::KernelExec *kernel = kernels.at(kernel_index);
    CHECK_NULL_RETURN(kernel);
    Format kernel_format = kernel->desc().format;
    if (kernel_format == format_) {
      // Kernel already runs in the graph format; no conversion needed.
      continue;
    }

    // to be realized: 1. get type from kernel; 2. flag for transpose weight.
    std::string type_name = kernel::TypeName(kernel->type());
    auto find_result = cloud_format_kernel_list.find(type_name);
    if (find_result == cloud_format_kernel_list.end()) {
      // Not in the whitelist: leave the node untouched (informational only).
      MS_LOG(INFO) << "Kernel(" << kernel->name() << ") has different format(" << FormatEnumToString(kernel_format)
                   << ") with graph format(" << FormatEnumToString(format_) << "), but not in insert-transpose white "
                   << "list and will not insert transpose kernel, type name: " << type_name;
      continue;
    }

    // The whitelist entry gives the input indices that need layout conversion.
    auto insert_input_list = find_result->second;
    for (auto index : insert_input_list) {
      if (index >= kernel->in_tensors().size()) {
        // Optional input not present on this node; skip.
        continue;
      }

      if (kernel->in_tensors().at(index)->IsConst()) {
        // Const data is transposed eagerly; failure is deliberately ignored
        // (best effort) — NOTE(review): confirm errors here are recoverable.
        (void)TransposeConstData(kernel, index);
        continue;
      }
      // Runtime input: insert a transpose node converting graph format ->
      // kernel format in front of this input.
      auto ret = InsertPreTranspose(graph, kernel, tensors, TransInfoPair(format_, kernel_format), index,
                                    create_format_transpose_func_);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Insert pre transpose for op: " << kernel->name() << ", index: " << index << ", failed";
        return RET_ERROR;
      }
    }

    // Convert every output edge back to the graph format (kernel format ->
    // graph format), one transpose per consumer kernel.
    for (size_t i = 0; i < kernel->out_kernels().size(); i++) {
      auto ret = InsertPostTranspose(graph, kernel, tensors, TransInfoPair(kernel_format, format_), i,
                                     create_format_transpose_func_);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Insert post transpose for op: " << kernel->name() << ", index: " << i << ", failed";
        return RET_ERROR;
      }
    }
    // graph output node has no output kernels, take care of these nodes
    if (IsContain(graph->out_nodes(), kernel)) {
      for (size_t i = 0; i < kernel->out_tensors().size(); i++) {
        auto ret = InsertPostTranspose(graph, kernel, tensors, TransInfoPair(kernel_format, format_), i,
                                       create_format_transpose_func_);
        if (ret != RET_OK) {
          MS_LOG(ERROR) << "Insert post transpose for op: " << kernel->name() << ", index: " << i << ", failed";
          return RET_ERROR;
        }
      }
    }
    MS_LOG(INFO) << "Insert transpose before and after node: " << kernel->name();

    // Inserting nodes may have changed which nodes are subgraph boundaries;
    // recompute the in/out node sets from the updated node list.
    graph->SetInNodes(kernel::KernelExecUtil::SubgraphInputNodes(graph->nodes()));
    graph->SetOutNodes(kernel::KernelExecUtil::SubgraphOutputNodes(graph->nodes()));

    // Re-establish execution order after the structural changes.
    auto ret = graph->TopologicalSortNodes();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Topological sort kernels failed.";
      return RET_ERROR;
    }
  }

  return RET_OK;
}
129 } // namespace mindspore::lite::pass
130