/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/litert/pass/format_pass/insert_transpose.h"
#include "src/litert/pass/format_pass/format_utils.h"
#include "src/litert/kernel_exec_util.h"
#include "nnacl/base/format_transpose.h"

namespace mindspore::lite::pass {
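// Transposes the const input tensor of `kernel` at position `index` in place so
// that its data layout matches the format the kernel expects.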
int InsertTranspose::TransposeConstData(kernel::KernelExec *kernel, size_t index) {
  lite::Tensor *tensor = kernel->in_tensors().at(index);
  Format except_format = kernel->desc().format;
  if (tensor->format() == except_format) {
    return RET_OK;
  }

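  // Const tensor data is expected to be held directly (no allocator); otherwise
  // the in-place buffer replacement below would be unsafe.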
  if (tensor->allocator() != nullptr) {
    MS_LOG(ERROR) << "Const data allocator invalid.";
    return RET_ERROR;
  }

  void *buffer = malloc(tensor->Size());
  if (buffer == nullptr) {
    MS_LOG(ERROR) << "malloc transpose data failed";
    return RET_ERROR;
  }
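  // Transpose the raw data into the expected format; height and width are
  // flattened into a single spatial plane for the conversion.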
  auto ret = TransData(tensor->data(), buffer, (FormatC)(tensor->format()), (FormatC)except_format,
                       static_cast<TypeIdC>(tensor->data_type()), tensor->Batch(), tensor->Channel(),
                       tensor->Height() * tensor->Width());
  if (ret != RET_OK) {
    free(buffer);  // avoid leaking the transpose buffer on failure
    return ret;
  }

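  // Release the old payload and attach the transposed buffer (the tensor takes ownership).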
  tensor->FreeData();
  tensor->set_data(buffer, true);
  if (!TransTensorShapeAndFormat(tensor, except_format)) {
    MS_LOG(ERROR) << "unsupported except format: " << except_format;
    return RET_ERROR;
  }
  return RET_OK;
}

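// For every kernel whose format differs from the graph format, inserts transpose
// kernels around its variable inputs/outputs (for white-listed kernel types) and
// converts its const inputs in place.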
int InsertTranspose::RunPass(kernel::SubGraphKernel *graph, std::vector<lite::Tensor *> *tensors) {
  auto kernels = graph->nodes();

  auto origin_kernel_size = kernels.size();
  for (size_t kernel_index = 0; kernel_index < origin_kernel_size; kernel_index++) {
    kernel::KernelExec *kernel = kernels.at(kernel_index);
    CHECK_NULL_RETURN(kernel);
    Format kernel_format = kernel->desc().format;
    if (kernel_format == format_) {
      continue;
    }

    // TODO: 1. get the type from the kernel; 2. add a flag for transposing weights.
    std::string type_name = kernel::TypeName(kernel->type());
    auto find_result = cloud_format_kernel_list.find(type_name);
    if (find_result == cloud_format_kernel_list.end()) {
      MS_LOG(INFO) << "Kernel(" << kernel->name() << ") has a different format(" << FormatEnumToString(kernel_format)
                   << ") from the graph format(" << FormatEnumToString(format_) << "), but is not in the "
                   << "insert-transpose white list, so no transpose kernel is inserted, type name: " << type_name;
      continue;
    }

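    // The white-list entry maps this kernel type to the input indices that need layout conversion.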
    auto insert_input_list = find_result->second;
    for (auto index : insert_input_list) {
      if (index >= kernel->in_tensors().size()) {
        continue;
      }

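      // Const inputs are converted once at pass time instead of inserting a runtime transpose kernel.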
      if (kernel->in_tensors().at(index)->IsConst()) {
        auto trans_ret = TransposeConstData(kernel, index);
        if (trans_ret != RET_OK) {
          MS_LOG(ERROR) << "Transpose const data for op: " << kernel->name() << ", index: " << index << ", failed";
          return trans_ret;
        }
        continue;
      }
      auto ret = InsertPreTranspose(graph, kernel, tensors, TransInfoPair(format_, kernel_format), index,
                                    create_format_transpose_func_);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Insert pre transpose for op: " << kernel->name() << ", index: " << index << ", failed";
        return RET_ERROR;
      }
    }

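    // Restore the graph format on every edge to a downstream kernel.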
    for (size_t i = 0; i < kernel->out_kernels().size(); i++) {
      auto ret = InsertPostTranspose(graph, kernel, tensors, TransInfoPair(kernel_format, format_), i,
                                     create_format_transpose_func_);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Insert post transpose for op: " << kernel->name() << ", index: " << i << ", failed";
        return RET_ERROR;
      }
    }
    // Graph output nodes have no output kernels; handle their output tensors explicitly.
    if (IsContain(graph->out_nodes(), kernel)) {
      for (size_t i = 0; i < kernel->out_tensors().size(); i++) {
        auto ret = InsertPostTranspose(graph, kernel, tensors, TransInfoPair(kernel_format, format_), i,
                                       create_format_transpose_func_);
        if (ret != RET_OK) {
          MS_LOG(ERROR) << "Insert post transpose for op: " << kernel->name() << ", index: " << i << ", failed";
          return RET_ERROR;
        }
      }
    }
    MS_LOG(INFO) << "Insert transpose before and after node: " << kernel->name();

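    // Inserted transpose kernels may change the subgraph boundary; refresh the in/out nodes and re-sort.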
    graph->SetInNodes(kernel::KernelExecUtil::SubgraphInputNodes(graph->nodes()));
    graph->SetOutNodes(kernel::KernelExecUtil::SubgraphOutputNodes(graph->nodes()));

    auto ret = graph->TopologicalSortNodes();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Topological sort kernels failed.";
      return RET_ERROR;
    }
  }

  return RET_OK;
}
}  // namespace mindspore::lite::pass