1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/litert/pass/format_pass/pass_utils.h"
18 #include <string>
19 #include <vector>
20 #include "nnacl/format_transpose_parameter.h"
21 #include "nnacl/arg_min_max_parameter.h"
22
23 namespace mindspore::lite::pass {
IsNoneTranspose(const TransInfoPair & trans)24 bool IsNoneTranspose(const TransInfoPair &trans) {
25 return trans.src_format_ == Format::DEFAULT_FORMAT && trans.dst_format_ == Format::DEFAULT_FORMAT;
26 }
27
IsSameTranspose(const TransInfoPair & trans0,const TransInfoPair & trans1)28 bool IsSameTranspose(const TransInfoPair &trans0, const TransInfoPair &trans1) {
29 if (!IsNoneTranspose(trans0) && !IsNoneTranspose(trans1)) {
30 return trans0.src_format_ == trans1.src_format_ && trans0.dst_format_ == trans1.dst_format_;
31 }
32 return false;
33 }
34
IsOppositiveTranspose(const TransInfoPair & trans0,const TransInfoPair & trans1)35 bool IsOppositiveTranspose(const TransInfoPair &trans0, const TransInfoPair &trans1) {
36 if (!IsNoneTranspose(trans0) && IsNoneTranspose(trans1)) {
37 return true;
38 } else if (IsNoneTranspose(trans0) && !IsNoneTranspose(trans1)) {
39 return true;
40 } else if (!IsNoneTranspose(trans0) && !IsNoneTranspose(trans1)) {
41 return trans0.src_format_ == trans1.dst_format_ && trans0.dst_format_ == trans1.src_format_;
42 } else {
43 return false;
44 }
45 }
46
SetShape(const Tensor * src_tensor,Tensor * dst_tensor)47 bool SetShape(const Tensor *src_tensor, Tensor *dst_tensor) {
48 auto shape = src_tensor->shape();
49 if (shape.size() != DIMENSION_4D) {
50 dst_tensor->set_shape(shape);
51 return true;
52 }
53 if (std::any_of(shape.begin(), shape.end(), [](int dim) { return dim == -1; })) {
54 dst_tensor->set_shape({-1});
55 return true;
56 }
57 bool ret;
58 auto new_shape = TransShape(src_tensor->shape(), {src_tensor->format(), dst_tensor->format()}, &ret);
59 if (!ret) {
60 MS_LOG(ERROR) << "Transpose shape of tensor failed";
61 return false;
62 }
63 dst_tensor->set_shape(new_shape);
64 return true;
65 }
66
SetShape4D(const Tensor * src_tensor,Tensor * dst_tensor)67 bool SetShape4D(const Tensor *src_tensor, Tensor *dst_tensor) {
68 auto shape = src_tensor->shape();
69 auto invalid_shape = {-1};
70 if (shape.size() != DIMENSION_4D) {
71 dst_tensor->set_shape(invalid_shape);
72 return true;
73 }
74 return SetShape(src_tensor, dst_tensor);
75 }
76
TransTensorShapeAndFormat(Tensor * tensor,Format dst_format)77 bool TransTensorShapeAndFormat(Tensor *tensor, Format dst_format) {
78 auto shape = tensor->shape();
79 if (shape.size() != DIMENSION_4D) {
80 tensor->set_shape(shape);
81 return true;
82 }
83 bool ret;
84 auto new_shape = TransShape(tensor->shape(), {tensor->format(), dst_format}, &ret);
85 if (!ret) {
86 MS_LOG(ERROR) << "Transpose shape of tensor failed";
87 return false;
88 }
89 tensor->set_shape(new_shape);
90 tensor->set_format(dst_format);
91 return true;
92 }
93
InsertPreTranspose(kernel::SubGraphKernel * subgraph,kernel::KernelExec * kernel,std::vector<Tensor * > * all_tensors,const TransInfoPair & trans_info,const size_t & index,const CreateFormatTransposeFunc & func)94 int InsertPreTranspose(kernel::SubGraphKernel *subgraph, kernel::KernelExec *kernel, std::vector<Tensor *> *all_tensors,
95 const TransInfoPair &trans_info, const size_t &index, const CreateFormatTransposeFunc &func) {
96 if (func == nullptr) {
97 MS_LOG(ERROR) << "CreateFormatTransposeFunc is nullptr.";
98 return RET_INPUT_PARAM_INVALID;
99 }
100 auto trans_name = kernel->name() + "_pre_" + std::to_string(index);
101 auto in_tensor = kernel->in_tensors().at(index);
102 auto in_tensor_shape = in_tensor->shape();
103 if (std::all_of(in_tensor_shape.begin(), in_tensor_shape.end(), [](const int &dim) { return dim >= 0; }) &&
104 in_tensor_shape.size() != DIMENSION_4D) {
105 MS_LOG(INFO) << index << "th input tensor of kernel " << kernel->name()
106 << " is infershaped and do not have 4 dimensions, skip inserting transpose kernel.";
107 return RET_OK;
108 }
109 auto out_tensor = new (std::nothrow) Tensor(in_tensor->data_type(), {}, (Format)trans_info.dst_format_);
110 CHECK_NULL_RETURN(out_tensor);
111 out_tensor->set_tensor_name(trans_name + "_output");
112 if (!SetShape4D(in_tensor, out_tensor)) {
113 MS_LOG(ERROR) << "Sync shape from in_tensor to out_tensor failed.";
114 delete out_tensor;
115 return RET_ERROR;
116 }
117
118 auto trans_kernel = func(in_tensor, out_tensor, trans_info, trans_name, kernel->Context(), kernel->desc());
119 if (trans_kernel == nullptr) {
120 delete out_tensor;
121 return RET_NULL_PTR;
122 }
123
124 all_tensors->push_back(out_tensor);
125 subgraph->InsertInEdge(kernel, trans_kernel, index);
126 return RET_OK;
127 }
128
InsertPostTranspose(kernel::SubGraphKernel * subgraph,kernel::KernelExec * kernel,std::vector<Tensor * > * all_tensors,const TransInfoPair & trans_info,const size_t & index,const CreateFormatTransposeFunc & func)129 int InsertPostTranspose(kernel::SubGraphKernel *subgraph, kernel::KernelExec *kernel,
130 std::vector<Tensor *> *all_tensors, const TransInfoPair &trans_info, const size_t &index,
131 const CreateFormatTransposeFunc &func) {
132 if (func == nullptr) {
133 MS_LOG(ERROR) << "CreateFormatTransposeFunc is nullptr.";
134 return RET_INPUT_PARAM_INVALID;
135 }
136 auto trans_name = kernel->name() + "_post_" + std::to_string(index);
137
138 auto out_tensor = kernel->out_tensors().at(index);
139 auto out_tensor_shape = out_tensor->shape();
140 if (std::all_of(out_tensor_shape.begin(), out_tensor_shape.end(), [](const int &dim) { return dim >= 0; }) &&
141 out_tensor_shape.size() != DIMENSION_4D) {
142 MS_LOG(INFO) << index << "th output tensor of kernel " << kernel->name()
143 << " is infershaped and do not have 4 dimensions, skip inserting transpose kernel.";
144 return RET_OK;
145 }
146 auto in_tensor = new (std::nothrow) Tensor(out_tensor->data_type(), {}, (Format)trans_info.src_format_);
147 CHECK_NULL_RETURN(in_tensor);
148 in_tensor->set_tensor_name(trans_name + "_input");
149 if (!SetShape4D(out_tensor, in_tensor)) {
150 MS_LOG(ERROR) << "Sync shape from in_tensor to out_tensor failed.";
151 delete out_tensor;
152 return RET_ERROR;
153 }
154
155 auto trans_kernel = func(in_tensor, out_tensor, trans_info, trans_name, kernel->Context(), kernel->desc());
156 if (trans_kernel == nullptr) {
157 delete out_tensor;
158 return RET_NULL_PTR;
159 }
160
161 all_tensors->push_back(in_tensor);
162 subgraph->InsertOutEdge(kernel, trans_kernel, index);
163 return RET_OK;
164 }
165
GetTransposeInfo(const kernel::KernelExec * kernel,TransInfoPair * trans_info)166 int GetTransposeInfo(const kernel::KernelExec *kernel, TransInfoPair *trans_info) {
167 CHECK_NULL_RETURN(kernel);
168 if (kernel->type() != schema::PrimitiveType_Transpose && kernel->type() != schema::PrimitiveType_FormatTranspose) {
169 return RET_INVALID_OP_ATTR;
170 }
171 if (kernel->type() == schema::PrimitiveType_Transpose) {
172 CHECK_LESS_RETURN(kernel->in_tensors().size(), FormatTransposeInput);
173 auto perm_tensor = kernel->in_tensors().at(1);
174 CHECK_NULL_RETURN(perm_tensor);
175 if (perm_tensor->ElementsNum() != DIMENSION_4D || perm_tensor->data_type() != kNumberTypeInt32) {
176 return RET_INVALID_OP_ATTR;
177 }
178 auto perm_data = reinterpret_cast<int *>(perm_tensor->data());
179 CHECK_NULL_RETURN(perm_data);
180 std::vector<int> perm;
181 for (int i = 0; i < perm_tensor->ElementsNum(); i++) {
182 perm.push_back(perm_data[i]);
183 }
184 if (perm == nc2nh_perm) {
185 trans_info->src_format_ = NCHW;
186 trans_info->dst_format_ = NHWC;
187 } else if (perm == nh2nc_perm) {
188 trans_info->src_format_ = NHWC;
189 trans_info->dst_format_ = NCHW;
190 } else {
191 return RET_INVALID_OP_ATTR;
192 }
193 }
194 if (kernel->type() == schema::PrimitiveType_FormatTranspose) {
195 auto param = reinterpret_cast<FormatTransposeParameter *>(kernel->op_parameter());
196 CHECK_NULL_RETURN(param);
197 trans_info->src_format_ = static_cast<Format>((param->src_format_));
198 trans_info->dst_format_ = static_cast<Format>((param->dst_format_));
199 }
200 return RET_OK;
201 }
202 } // namespace mindspore::lite::pass
203