/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/litert/pass/format_pass/format_pass.h"
#include "src/litert/pass/format_pass/insert_transpose.h"
#include "src/litert/pass/format_pass/eliminate_transpose.h"
#ifdef ENABLE_MULTI_LAYOUT
#include "src/litert/kernel_registry.h"
#include "nnacl/format_transpose_parameter.h"
#endif
#include "src/common/draw/drawer.h"

namespace mindspore::lite::pass {
#ifdef ENABLE_MULTI_LAYOUT
namespace {
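// Creates a FormatTranspose kernel that converts `input` from trans_info.src_format_ to
// trans_info.dst_format_. The kernel is resolved through the KernelRegistry and wrapped in a
// KernelExec; on any failure the resources allocated so far are released and nullptr is returned.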
kernel::KernelExec *DefaultCreateFormatTransFunc(Tensor *input, Tensor *output, const TransInfoPair &trans_info,
                                                 const std::string &name, const lite::InnerContext *ctx,
                                                 const kernel::KernelKey &desc) {
  auto param = reinterpret_cast<FormatTransposeParameter *>(malloc(sizeof(FormatTransposeParameter)));
  if (param == nullptr) {
    MS_LOG(ERROR) << "Malloc FormatTransposeParameter failed.";
    return nullptr;
  }
  (void)memset(param, 0, sizeof(FormatTransposeParameter));
  param->op_parameter_.type_ = static_cast<int>(schema::PrimitiveType_FormatTranspose);
  param->src_format_ = static_cast<FormatC>((trans_info.src_format_));
  param->dst_format_ = static_cast<FormatC>((trans_info.dst_format_));
  kernel::KernelKey format_transpose_key = desc;
  format_transpose_key.type = schema::PrimitiveType_FormatTranspose;
  format_transpose_key.format = NHWC;
  format_transpose_key.data_type = input->data_type();

  kernel::MSKernel *kernel_impl;
  auto lite_kernel = KernelRegistry::GetInstance()->GetLiteKernel({input}, {output}, ctx, &format_transpose_key,
                                                                  reinterpret_cast<OpParameter *>(param));
  if (lite_kernel == nullptr) {
    MS_LOG(ERROR) << "Create format-transpose lite-kernel failed.";
    free(param);
    return nullptr;
  }
  kernel_impl = lite_kernel;
  auto *kernel_exec = new (std::nothrow) kernel::KernelExec(std::shared_ptr<kernel::MSKernel>(kernel_impl));
  if (kernel_exec == nullptr) {
    MS_LOG(ERROR) << "Create format-transpose kernel-exec failed.";
    // The lite kernel was never handed over to a KernelExec, so release it here to avoid a leak.
    delete kernel_impl;
    return nullptr;
  }
  kernel_exec->set_desc(format_transpose_key);
  kernel_exec->set_context(ctx);
  kernel_exec->set_name(name);
  return kernel_exec;
}
}  // namespace
#endif

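// Appends a pass to the optimizer's pass list; null passes are rejected.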
int FormatOptimize::AddPass(const FormatPassPtr &pass) {
  CHECK_NULL_RETURN(pass);
  pass_list_.push_back(pass);
  return RET_OK;
}

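// Runs every registered pass over the subgraph in insertion order, dumping the graph via DrawDot
// after each pass and aborting on the first pass that fails.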
int FormatOptimize::RunPass(kernel::SubGraphKernel *graph, std::vector<Tensor *> *tensors) {
  for (const FormatPassPtr &pass : pass_list_) {
    CHECK_NULL_RETURN(pass);

    auto status = pass->RunPass(graph, tensors);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Run pass " << pass->name() << " failed.";
      return status;
    }
    DrawDot(graph, pass->name());
  }
  return RET_OK;
}

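// Applies the two built-in format passes (InsertTranspose, then EliminateTranspose) to every
// subgraph, using create_format_transpose_func to materialize the required transpose kernels.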
int DoFormatPass(std::vector<mindspore::kernel::KernelExec *> *subgraph_list,
                 std::vector<mindspore::lite::Tensor *> *tensors, mindspore::Format graph_format,
                 const CreateFormatTransposeFunc &create_format_transpose_func) {
  for (const auto &subgraph : *subgraph_list) {
    FormatOptimizePtr optimize = std::make_shared<FormatOptimize>();

    (void)optimize->AddPass(std::make_shared<InsertTranspose>(graph_format, create_format_transpose_func));
    (void)optimize->AddPass(std::make_shared<EliminateTranspose>(graph_format, create_format_transpose_func));

    auto graph = reinterpret_cast<kernel::SubGraphKernel *>(subgraph);
    auto ret = optimize->RunPass(graph, tensors);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Runtime format pass failed.";
      return RET_ERROR;
    }
  }

  return RET_OK;
}

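// Runtime entry point of the format pass. A no-op when ENABLE_MULTI_LAYOUT is disabled; otherwise
// runs DoFormatPass with the caller-provided transpose-kernel factory, falling back to
// DefaultCreateFormatTransFunc when none is supplied.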
int RuntimeFormatPass(std::vector<mindspore::kernel::KernelExec *> *subgraph_list,
                      std::vector<mindspore::lite::Tensor *> *tensors, mindspore::Format graph_format,
                      const CreateFormatTransposeFunc &create_format_transpose_func) {
#ifndef ENABLE_MULTI_LAYOUT
  return RET_OK;
#else
  if (create_format_transpose_func == nullptr) {
    return DoFormatPass(subgraph_list, tensors, graph_format, DefaultCreateFormatTransFunc);
  } else {
    return DoFormatPass(subgraph_list, tensors, graph_format, create_format_transpose_func);
  }
#endif
}
}  // namespace mindspore::lite::pass