/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/litert/pass/format_pass/pass_utils.h"
#include <algorithm>
#include <string>
#include <vector>
#include "nnacl/format_transpose_parameter.h"
#include "nnacl/arg_min_max_parameter.h"

namespace mindspore::lite::pass {
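// Returns true when both the source and destination formats are the default
// format, i.e. the pair describes no real transposition.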
bool IsNoneTranspose(const TransInfoPair &trans) {
  return trans.src_format_ == Format::DEFAULT_FORMAT && trans.dst_format_ == Format::DEFAULT_FORMAT;
}

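// Two transpositions are considered the same only when both are real
// transpositions and their source and destination formats match.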
bool IsSameTranspose(const TransInfoPair &trans0, const TransInfoPair &trans1) {
  if (!IsNoneTranspose(trans0) && !IsNoneTranspose(trans1)) {
    return trans0.src_format_ == trans1.src_format_ && trans0.dst_format_ == trans1.dst_format_;
  }
  return false;
}

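// Returns true when exactly one of the two transpositions is a no-op, or when
// both are real transpositions that invert each other (e.g. NHWC->NCHW
// followed by NCHW->NHWC).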
bool IsOppositiveTranspose(const TransInfoPair &trans0, const TransInfoPair &trans1) {
  if (!IsNoneTranspose(trans0) && IsNoneTranspose(trans1)) {
    return true;
  } else if (IsNoneTranspose(trans0) && !IsNoneTranspose(trans1)) {
    return true;
  } else if (!IsNoneTranspose(trans0) && !IsNoneTranspose(trans1)) {
    return trans0.src_format_ == trans1.dst_format_ && trans0.dst_format_ == trans1.src_format_;
  } else {
    return false;
  }
}

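// Derives dst_tensor's shape from src_tensor: non-4D shapes are copied as-is,
// shapes containing a dynamic dimension (-1) collapse to {-1}, and 4D shapes
// are permuted from src_tensor's format to dst_tensor's format.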
bool SetShape(const Tensor *src_tensor, Tensor *dst_tensor) {
  auto shape = src_tensor->shape();
  if (shape.size() != DIMENSION_4D) {
    dst_tensor->set_shape(shape);
    return true;
  }
  if (std::any_of(shape.begin(), shape.end(), [](int dim) { return dim == -1; })) {
    dst_tensor->set_shape({-1});
    return true;
  }
  bool ret;
  auto new_shape = TransShape(src_tensor->shape(), {src_tensor->format(), dst_tensor->format()}, &ret);
  if (!ret) {
    MS_LOG(ERROR) << "Transpose shape of tensor failed";
    return false;
  }
  dst_tensor->set_shape(new_shape);
  return true;
}

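// Same as SetShape, except that a non-4D source shape is treated as invalid:
// dst_tensor's shape is set to {-1} instead of being copied.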
bool SetShape4D(const Tensor *src_tensor, Tensor *dst_tensor) {
  auto shape = src_tensor->shape();
  auto invalid_shape = {-1};
  if (shape.size() != DIMENSION_4D) {
    dst_tensor->set_shape(invalid_shape);
    return true;
  }
  return SetShape(src_tensor, dst_tensor);
}

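// Permutes a 4D tensor's shape from its current format to dst_format and
// updates the tensor's format accordingly. Non-4D tensors are left unchanged.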
bool TransTensorShapeAndFormat(Tensor *tensor, Format dst_format) {
  auto shape = tensor->shape();
  if (shape.size() != DIMENSION_4D) {
    // Nothing to permute for non-4D tensors.
    return true;
  }
  bool ret;
  auto new_shape = TransShape(tensor->shape(), {tensor->format(), dst_format}, &ret);
  if (!ret) {
    MS_LOG(ERROR) << "Transpose shape of tensor failed";
    return false;
  }
  tensor->set_shape(new_shape);
  tensor->set_format(dst_format);
  return true;
}

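// Inserts a format-transpose kernel, created by func, in front of the
// index-th input of kernel. The new intermediate tensor that carries the
// transposed data is appended to all_tensors. Inputs whose shape is fully
// inferred but not 4-dimensional need no transposition and are skipped.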
int InsertPreTranspose(kernel::SubGraphKernel *subgraph, kernel::KernelExec *kernel, std::vector<Tensor *> *all_tensors,
                       const TransInfoPair &trans_info, const size_t &index, const CreateFormatTransposeFunc &func) {
  if (func == nullptr) {
    MS_LOG(ERROR) << "CreateFormatTransposeFunc is nullptr.";
    return RET_INPUT_PARAM_INVALID;
  }
  auto trans_name = kernel->name() + "_pre_" + std::to_string(index);
  auto in_tensor = kernel->in_tensors().at(index);
  auto in_tensor_shape = in_tensor->shape();
  if (std::all_of(in_tensor_shape.begin(), in_tensor_shape.end(), [](const int &dim) { return dim >= 0; }) &&
      in_tensor_shape.size() != DIMENSION_4D) {
    MS_LOG(INFO) << index << "th input tensor of kernel " << kernel->name()
                 << " has an inferred shape that is not 4-dimensional, skip inserting transpose kernel.";
    return RET_OK;
  }
  auto out_tensor = new (std::nothrow) Tensor(in_tensor->data_type(), {}, (Format)trans_info.dst_format_);
  CHECK_NULL_RETURN(out_tensor);
  out_tensor->set_tensor_name(trans_name + "_output");
  if (!SetShape4D(in_tensor, out_tensor)) {
    MS_LOG(ERROR) << "Sync shape from in_tensor to out_tensor failed.";
    delete out_tensor;
    return RET_ERROR;
  }

  auto trans_kernel = func(in_tensor, out_tensor, trans_info, trans_name, kernel->Context(), kernel->desc());
  if (trans_kernel == nullptr) {
    delete out_tensor;
    return RET_NULL_PTR;
  }

  all_tensors->push_back(out_tensor);
  subgraph->InsertInEdge(kernel, trans_kernel, index);
  return RET_OK;
}

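// Counterpart of InsertPreTranspose: inserts a format-transpose kernel behind
// the index-th output of kernel. The new intermediate tensor becomes the
// input of the inserted transpose kernel and is appended to all_tensors.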
int InsertPostTranspose(kernel::SubGraphKernel *subgraph, kernel::KernelExec *kernel,
                        std::vector<Tensor *> *all_tensors, const TransInfoPair &trans_info, const size_t &index,
                        const CreateFormatTransposeFunc &func) {
  if (func == nullptr) {
    MS_LOG(ERROR) << "CreateFormatTransposeFunc is nullptr.";
    return RET_INPUT_PARAM_INVALID;
  }
  auto trans_name = kernel->name() + "_post_" + std::to_string(index);

  auto out_tensor = kernel->out_tensors().at(index);
  auto out_tensor_shape = out_tensor->shape();
  if (std::all_of(out_tensor_shape.begin(), out_tensor_shape.end(), [](const int &dim) { return dim >= 0; }) &&
      out_tensor_shape.size() != DIMENSION_4D) {
    MS_LOG(INFO) << index << "th output tensor of kernel " << kernel->name()
                 << " has an inferred shape that is not 4-dimensional, skip inserting transpose kernel.";
    return RET_OK;
  }
  auto in_tensor = new (std::nothrow) Tensor(out_tensor->data_type(), {}, (Format)trans_info.src_format_);
  CHECK_NULL_RETURN(in_tensor);
  in_tensor->set_tensor_name(trans_name + "_input");
  if (!SetShape4D(out_tensor, in_tensor)) {
    MS_LOG(ERROR) << "Sync shape from out_tensor to in_tensor failed.";
    // Delete the newly allocated tensor; out_tensor still belongs to the kernel.
    delete in_tensor;
    return RET_ERROR;
  }

  auto trans_kernel = func(in_tensor, out_tensor, trans_info, trans_name, kernel->Context(), kernel->desc());
  if (trans_kernel == nullptr) {
    delete in_tensor;
    return RET_NULL_PTR;
  }

  all_tensors->push_back(in_tensor);
  subgraph->InsertOutEdge(kernel, trans_kernel, index);
  return RET_OK;
}

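// Extracts the source/destination format pair from a Transpose or
// FormatTranspose kernel. For a Transpose kernel, the permutation tensor must
// be a 4-element int32 tensor matching either the NCHW->NHWC or NHWC->NCHW
// permutation; anything else is reported as RET_INVALID_OP_ATTR.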
int GetTransposeInfo(const kernel::KernelExec *kernel, TransInfoPair *trans_info) {
  CHECK_NULL_RETURN(kernel);
  if (kernel->type() != schema::PrimitiveType_Transpose && kernel->type() != schema::PrimitiveType_FormatTranspose) {
    return RET_INVALID_OP_ATTR;
  }
  if (kernel->type() == schema::PrimitiveType_Transpose) {
    CHECK_LESS_RETURN(kernel->in_tensors().size(), FormatTransposeInput);
    auto perm_tensor = kernel->in_tensors().at(1);
    CHECK_NULL_RETURN(perm_tensor);
    if (perm_tensor->ElementsNum() != DIMENSION_4D || perm_tensor->data_type() != kNumberTypeInt32) {
      return RET_INVALID_OP_ATTR;
    }
    auto perm_data = reinterpret_cast<int *>(perm_tensor->data());
    CHECK_NULL_RETURN(perm_data);
    std::vector<int> perm;
    for (int i = 0; i < perm_tensor->ElementsNum(); i++) {
      perm.push_back(perm_data[i]);
    }
    if (perm == nc2nh_perm) {
      trans_info->src_format_ = NCHW;
      trans_info->dst_format_ = NHWC;
    } else if (perm == nh2nc_perm) {
      trans_info->src_format_ = NHWC;
      trans_info->dst_format_ = NCHW;
    } else {
      return RET_INVALID_OP_ATTR;
    }
  }
  if (kernel->type() == schema::PrimitiveType_FormatTranspose) {
    auto param = reinterpret_cast<FormatTransposeParameter *>(kernel->op_parameter());
    CHECK_NULL_RETURN(param);
    trans_info->src_format_ = static_cast<Format>(param->src_format_);
    trans_info->dst_format_ = static_cast<Format>(param->dst_format_);
  }
  return RET_OK;
}
}  // namespace mindspore::lite::pass