1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/litert/pass/format_pass/format_pass.h"
18 #include "src/litert/pass/format_pass/insert_transpose.h"
19 #include "src/litert/pass/format_pass/eliminate_transpose.h"
20 #ifdef ENABLE_MULTI_LAYOUT
21 #include "src/litert/kernel_registry.h"
22 #include "nnacl/format_transpose_parameter.h"
23 #endif
24 #include "src/common/draw/drawer.h"
25
26 namespace mindspore::lite::pass {
27 #ifdef ENABLE_MULTI_LAYOUT
28 namespace {
DefaultCreateFormatTransFunc(Tensor * input,Tensor * output,const TransInfoPair & trans_info,const std::string & name,const lite::InnerContext * ctx,const kernel::KernelKey & desc)29 kernel::KernelExec *DefaultCreateFormatTransFunc(Tensor *input, Tensor *output, const TransInfoPair &trans_info,
30 const std::string &name, const lite::InnerContext *ctx,
31 const kernel::KernelKey &desc) {
32 auto param = reinterpret_cast<FormatTransposeParameter *>(malloc(sizeof(FormatTransposeParameter)));
33 if (param == nullptr) {
34 MS_LOG(ERROR) << "Malloc FormatTransposeParameter failed.";
35 return nullptr;
36 }
37 (void)memset(param, 0, sizeof(FormatTransposeParameter));
38 param->op_parameter_.type_ = static_cast<int>(schema::PrimitiveType_FormatTranspose);
39 param->src_format_ = static_cast<FormatC>((trans_info.src_format_));
40 param->dst_format_ = static_cast<FormatC>((trans_info.dst_format_));
41 kernel::KernelKey format_transpose_key = desc;
42 format_transpose_key.type = schema::PrimitiveType_FormatTranspose;
43 format_transpose_key.format = NHWC;
44 format_transpose_key.data_type = input->data_type();
45
46 kernel::MSKernel *kernel_impl;
47 auto lite_kernel = KernelRegistry::GetInstance()->GetLiteKernel({input}, {output}, ctx, &format_transpose_key,
48 reinterpret_cast<OpParameter *>(param));
49 if (lite_kernel == nullptr) {
50 MS_LOG(ERROR) << "Create format-transpose lite-kernel failed.";
51 free(param);
52 return nullptr;
53 }
54 kernel_impl = lite_kernel;
55 auto *kernel_exec = new (std::nothrow) kernel::KernelExec(std::shared_ptr<kernel::MSKernel>(kernel_impl));
56 if (kernel_exec == nullptr) {
57 MS_LOG(ERROR) << "Create format-transpose kernel-exec failed.";
58 return nullptr;
59 }
60 kernel_exec->set_desc(format_transpose_key);
61 kernel_exec->set_context(ctx);
62 kernel_exec->set_name(name);
63 return kernel_exec;
64 }
65 } // namespace
66 #endif
67
AddPass(const FormatPassPtr & pass)68 int FormatOptimize::AddPass(const FormatPassPtr &pass) {
69 CHECK_NULL_RETURN(pass);
70 pass_list_.push_back(pass);
71 return RET_OK;
72 }
73
RunPass(kernel::SubGraphKernel * graph,std::vector<Tensor * > * tensors)74 int FormatOptimize::RunPass(kernel::SubGraphKernel *graph, std::vector<Tensor *> *tensors) {
75 for (const FormatPassPtr &pass : pass_list_) {
76 CHECK_NULL_RETURN(pass);
77
78 auto status = pass->RunPass(graph, tensors);
79 if (status != RET_OK) {
80 MS_LOG(ERROR) << "Run pass failed";
81 return status;
82 }
83 DrawDot(graph, pass->name());
84 }
85 return RET_OK;
86 }
87
DoFormatPass(std::vector<mindspore::kernel::KernelExec * > * subgraph_list,std::vector<mindspore::lite::Tensor * > * tensors,mindspore::Format graph_format,const CreateFormatTransposeFunc & create_format_transpose_func)88 int DoFormatPass(std::vector<mindspore::kernel::KernelExec *> *subgraph_list,
89 std::vector<mindspore::lite::Tensor *> *tensors, mindspore::Format graph_format,
90 const CreateFormatTransposeFunc &create_format_transpose_func) {
91 for (const auto &subgraph : *subgraph_list) {
92 FormatOptimizePtr optimize = std::make_shared<FormatOptimize>();
93
94 (void)optimize->AddPass(std::make_shared<InsertTranspose>(graph_format, create_format_transpose_func));
95 (void)optimize->AddPass(std::make_shared<EliminateTranspose>(graph_format, create_format_transpose_func));
96
97 auto graph = reinterpret_cast<kernel::SubGraphKernel *>(subgraph);
98 auto ret = optimize->RunPass(graph, tensors);
99 if (ret != RET_OK) {
100 MS_LOG(ERROR) << "Runtime format pass failed.";
101 return RET_ERROR;
102 }
103 }
104
105 return RET_OK;
106 }
107
RuntimeFormatPass(std::vector<mindspore::kernel::KernelExec * > * subgraph_list,std::vector<mindspore::lite::Tensor * > * tensors,mindspore::Format graph_format,const CreateFormatTransposeFunc & create_format_transpose_func)108 int RuntimeFormatPass(std::vector<mindspore::kernel::KernelExec *> *subgraph_list,
109 std::vector<mindspore::lite::Tensor *> *tensors, mindspore::Format graph_format,
110 const CreateFormatTransposeFunc &create_format_transpose_func) {
111 #ifndef ENABLE_MULTI_LAYOUT
112 return RET_OK;
113 #else
114 if (create_format_transpose_func == nullptr) {
115 return DoFormatPass(subgraph_list, tensors, graph_format, DefaultCreateFormatTransFunc);
116 } else {
117 return DoFormatPass(subgraph_list, tensors, graph_format, create_format_transpose_func);
118 }
119 #endif
120 }
121 } // namespace mindspore::lite::pass
122