/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/litert/pass/format_pass/transpose_strategy.h"
#include "nnacl/op_base.h"
#include "nnacl/arg_min_max_parameter.h"
#include "nnacl/concat_parameter.h"
#include "nnacl/crop_parameter.h"
#include "nnacl/softmax_parameter.h"
#include "nnacl/split_parameter.h"
#include "nnacl/squeeze_parameter.h"
#include "nnacl/stack_parameter.h"
#include "nnacl/unsqueeze_parameter.h"
#include "nnacl/unstack_parameter.h"
#include "nnacl/slice_parameter.h"
#include "nnacl/strided_slice_parameter.h"

namespace mindspore::lite::pass {
static const std::set<schema::PrimitiveType> arithmetic_kernel_lists = {
  schema::PrimitiveType_AddFusion,    schema::PrimitiveType_AddN,
  schema::PrimitiveType_DivFusion,    schema::PrimitiveType_Eltwise,
  schema::PrimitiveType_Equal,        schema::PrimitiveType_FloorDiv,
  schema::PrimitiveType_FloorMod,     schema::PrimitiveType_Greater,
  schema::PrimitiveType_GreaterEqual, schema::PrimitiveType_Less,
  schema::PrimitiveType_LessEqual,    schema::PrimitiveType_LogicalAnd,
  schema::PrimitiveType_LogicalOr,    schema::PrimitiveType_Maximum,
  schema::PrimitiveType_Minimum,      schema::PrimitiveType_Mod,
  schema::PrimitiveType_MulFusion,    schema::PrimitiveType_NotEqual,
  schema::PrimitiveType_RealDiv,      schema::PrimitiveType_SquaredDifference,
  schema::PrimitiveType_SubFusion,
};
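
// Count the transpose kernels in `kernels` that share a single transpose info. The first transpose found
// initializes *trans_info; after that, only transposes identical to *trans_info are counted.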
size_t TransposeStrategy::GetTransCount(const std::vector<kernel::KernelExec *> &kernels, TransInfoPair *trans_info) {
  size_t count = 0;
  for (const auto &in_kernel : kernels) {
    TransInfoPair tmp_trans;
    if (GetTransposeInfo(in_kernel, &tmp_trans) != RET_OK) {
      continue;
    }
    if (IsNoneTranspose(*trans_info)) {
      *trans_info = tmp_trans;
      count++;
    } else if (IsSameTranspose(*trans_info, tmp_trans)) {
      count++;
    }
  }
  return count;
}
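
// Returns true if the input tensor shapes of `kernel` allow it to execute in `runtime_format` after fusion.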
bool CheckInTensorsShape(const kernel::KernelExec *kernel, const Format &runtime_format) {
  // If the fusion is valid, the kernel will be executed in runtime_format.
  // Only check the input tensors of arithmetic (two-input) kernels.
  // If broadcast for various formats is supported, this function can be deleted.
  // e.g.: tensor 1 shape (1, 128, 24, 24), tensor 2 shape (1, 128, 1, 1): the NC4HW4 format is not supported now.
  if (arithmetic_kernel_lists.find(kernel::SchemaType(kernel->type())) == arithmetic_kernel_lists.end()) {
    return true;
  }
  for (const auto &in_tensor : kernel->in_tensors()) {
    const auto &in_shape = in_tensor->shape();
    if (std::any_of(in_shape.begin(), in_shape.end(), [](const int &dim) { return dim == -1; })) {
      return false;
    }
  }
  const auto &in0_shape = kernel->in_tensors().at(0)->shape();
  if (runtime_format == NHWC || runtime_format == NCHW) {
    // For the NCHW or NHWC format, the input ranks (shape sizes) must be equal.
    if (std::any_of(kernel->in_tensors().begin(), kernel->in_tensors().end(),
                    [&in0_shape](const Tensor *in_tensor) { return in_tensor->shape().size() != in0_shape.size(); })) {
      return false;
    }
  } else {
    // For other formats (NCxHWx), the input shapes must be identical.
    if (std::any_of(kernel->in_tensors().begin(), kernel->in_tensors().end(),
                    [&in0_shape](const Tensor *in_tensor) { return in_tensor->shape() != in0_shape; })) {
      return false;
    }
  }
  return true;
}

namespace {
using TransAxisFunc = std::function<int(kernel::KernelExec *, const TransInfoPair &)>;
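// Sentinel handlers: kNoNeedTransAxisFunc marks kernels that need no axis adjustment;
// kNotImplementedTransAxisFunc marks kernels whose axis adjustment is not implemented yet.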
TransAxisFunc kNoNeedTransAxisFunc = [](kernel::KernelExec *, const TransInfoPair &) { return RET_OK; };
TransAxisFunc kNotImplementedTransAxisFunc = [](kernel::KernelExec *, const TransInfoPair &) { return RET_ERROR; };
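
// The Handle*Kernel functions below rewrite a kernel's axis attribute with TransFormAxis, which (presumably,
// being part of the format pass) maps an axis index from the transpose's source layout to its destination
// layout, e.g. the NCHW channel axis 1 maps to the NHWC channel axis 3 for an NCHW2NHWC transpose.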
int HandleArgMinMaxKernel(const kernel::KernelExec *kernel, const TransInfoPair &trans) {
  auto arg_min_max_param = reinterpret_cast<ArgMinMaxParameter *>(kernel->op_parameter());
  CHECK_NULL_RETURN(arg_min_max_param);
  arg_min_max_param->axis_ = TransFormAxis(arg_min_max_param->axis_, trans);
  return RET_OK;
}

int HandleSoftMaxKernel(const kernel::KernelExec *kernel, const TransInfoPair &trans) {
  // An nnacl kernel needs its op_parameter transposed, while a BaseOperator kernel needs its Primitive transposed.
  auto param = reinterpret_cast<SoftmaxParameter *>(kernel->op_parameter());
  CHECK_NULL_RETURN(param);
  param->axis_ = TransFormAxis(param->axis_, trans);
  return RET_OK;
}

int HandleSplitKernel(const kernel::KernelExec *kernel, const TransInfoPair &trans) {
  auto param = reinterpret_cast<SplitParameter *>(kernel->op_parameter());
  CHECK_NULL_RETURN(param);
  param->split_dim_ = TransFormAxis(param->split_dim_, trans);
  return RET_OK;
}

int HandleSqueezeKernel(const kernel::KernelExec *kernel, const TransInfoPair &trans) {
  auto param = reinterpret_cast<SqueezeParameter *>(kernel->op_parameter());
  CHECK_NULL_RETURN(param);
  for (size_t i = 0; i < param->axis_size_; i++) {
    param->axis_[i] = TransFormAxis(param->axis_[i], trans);
  }
  return RET_OK;
}

int HandleUnSqueezeKernel(const kernel::KernelExec *kernel, const TransInfoPair &trans) {
  auto param = reinterpret_cast<UnSqueezeParameter *>(kernel->op_parameter());
  CHECK_NULL_RETURN(param);
  param->axis_ = TransFormAxis(param->axis_, trans);
  return RET_OK;
}

int HandleStackKernel(const kernel::KernelExec *kernel, const TransInfoPair &trans) {
  auto param = reinterpret_cast<StackParameter *>(kernel->op_parameter());
  CHECK_NULL_RETURN(param);
  param->axis_ = TransFormAxis(param->axis_, trans);
  return RET_OK;
}

int HandleUnStackKernel(const kernel::KernelExec *kernel, const TransInfoPair &trans) {
  auto param = reinterpret_cast<UnstackParameter *>(kernel->op_parameter());
  CHECK_NULL_RETURN(param);
  param->axis_ = TransFormAxis(param->axis_, trans);
  return RET_OK;
}

int HandleConcatKernel(const kernel::KernelExec *kernel, const TransInfoPair &trans) {
  auto concat_param = reinterpret_cast<ConcatParameter *>(kernel->op_parameter());
  CHECK_NULL_RETURN(concat_param);
  concat_param->axis_ = TransFormAxis(concat_param->axis_, trans);
  return RET_OK;
}

namespace {
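
// The HandleNAxisCrop helpers remap CropParameter offsets when a Crop kernel's layout changes between NCHW and
// NHWC. The crop axis_ selects the first cropped dimension, so offsets are permuted the same way the dimensions
// are, e.g. for axis_ == 0 the NCHW offsets (n, c, h, w) become the NHWC offsets (n, h, w, c).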
int Handle0AxisCrop(const TransInfoPair &trans, CropParameter *crop_param) {
  if (IsSameTranspose(trans, kNCHW2NHWCTrans)) {
    // offset_ is permuted in place, so copy the old values first to avoid reading already-overwritten entries.
    const auto offset_n = crop_param->offset_[kNCHW_N];
    const auto offset_h = crop_param->offset_[kNCHW_H];
    const auto offset_w = crop_param->offset_[kNCHW_W];
    const auto offset_c = crop_param->offset_[kNCHW_C];
    crop_param->offset_[kNHWC_N] = offset_n;
    crop_param->offset_[kNHWC_H] = offset_h;
    crop_param->offset_[kNHWC_W] = offset_w;
    crop_param->offset_[kNHWC_C] = offset_c;
    return RET_OK;
  }
  if (IsSameTranspose(trans, kNHWC2NCHWTrans)) {
    const auto offset_n = crop_param->offset_[kNHWC_N];
    const auto offset_h = crop_param->offset_[kNHWC_H];
    const auto offset_w = crop_param->offset_[kNHWC_W];
    const auto offset_c = crop_param->offset_[kNHWC_C];
    crop_param->offset_[kNCHW_N] = offset_n;
    crop_param->offset_[kNCHW_H] = offset_h;
    crop_param->offset_[kNCHW_W] = offset_w;
    crop_param->offset_[kNCHW_C] = offset_c;
    return RET_OK;
  }
  MS_LOG(ERROR) << "Unknown transpose info: from " << trans.src_format_ << " to " << trans.dst_format_;
  return RET_ERROR;
}

int Handle1AxisCrop(const TransInfoPair &trans, CropParameter *crop_param) {
  if (IsSameTranspose(trans, kNCHW2NHWCTrans)) {
    // Copy first for the same aliasing reason as in Handle0AxisCrop.
    const auto offset_h = crop_param->offset_[kNCHW_H - 1];
    const auto offset_w = crop_param->offset_[kNCHW_W - 1];
    const auto offset_c = crop_param->offset_[kNCHW_C - 1];
    crop_param->offset_[kNHWC_H - 1] = offset_h;
    crop_param->offset_[kNHWC_W - 1] = offset_w;
    crop_param->offset_[kNHWC_C - 1] = offset_c;
    return RET_OK;
  }
  if (IsSameTranspose(trans, kNHWC2NCHWTrans)) {
    const auto offset_h = crop_param->offset_[kNHWC_H - 1];
    const auto offset_w = crop_param->offset_[kNHWC_W - 1];
    const auto offset_c = crop_param->offset_[kNHWC_C - 1];
    crop_param->offset_[kNCHW_H - 1] = offset_h;
    crop_param->offset_[kNCHW_W - 1] = offset_w;
    crop_param->offset_[kNCHW_C - 1] = offset_c;
    return RET_OK;
  }
  MS_LOG(ERROR) << "Unknown transpose info: from " << trans.src_format_ << " to " << trans.dst_format_;
  return RET_ERROR;
}
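
// axis_ == 2 crops the last two dimensions. For NCHW2NHWC the cropped (H, W) offsets keep their order and a
// zero C offset is appended (only valid when the channels already match); for NHWC2NCHW the cropped (W, C)
// offsets are rearranged to (C, 0, W) starting from axis 1 (only valid when the heights already match).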
int Handle2AxisCrop(const TransInfoPair &trans, CropParameter *crop_param, const lite::Tensor &input_tensor,
                    const lite::Tensor &shape_tensor) {
  if (IsSameTranspose(trans, kNCHW2NHWCTrans)) {
    if (input_tensor.Channel() != shape_tensor.Channel()) {
      return RET_NO_CHANGE;
    }
    crop_param->offset_[4 - crop_param->axis_] = 0;
    crop_param->axis_ = crop_param->axis_ - 1;
    crop_param->offset_size_ = crop_param->offset_size_ + 1;
    return RET_OK;
  }
  auto offset1 = crop_param->offset_[1];
  auto offset0 = crop_param->offset_[0];
  if (IsSameTranspose(trans, kNHWC2NCHWTrans)) {
    if (input_tensor.Height() != shape_tensor.Height()) {
      return RET_NO_CHANGE;
    }
    crop_param->axis_ = 1;
    crop_param->offset_size_ = 3;
    crop_param->offset_[0] = offset1;
    crop_param->offset_[1] = 0;
    crop_param->offset_[2] = offset0;
    return RET_OK;
  }
  MS_LOG(ERROR) << "Unknown transpose info: from " << trans.src_format_ << " to " << trans.dst_format_;
  return RET_ERROR;
}
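
// axis_ == 3 crops only the last dimension. For NCHW2NHWC the W offset keeps its place and a zero C offset is
// appended, starting from axis 2; for NHWC2NCHW the single C offset moves to axis 1 with zero H and W offsets
// (only valid when the spatial dimensions already match).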
int Handle3AxisCrop(const TransInfoPair &trans, CropParameter *crop_param, const lite::Tensor &input_tensor,
                    const lite::Tensor &shape_tensor) {
  if (IsSameTranspose(trans, kNCHW2NHWCTrans)) {
    if (input_tensor.Channel() != shape_tensor.Channel()) {
      return RET_NO_CHANGE;
    }
    crop_param->offset_[4 - crop_param->axis_] = 0;
    crop_param->axis_ = crop_param->axis_ - 1;
    crop_param->offset_size_ = crop_param->offset_size_ + 1;
    return RET_OK;
  }
  auto offset0 = crop_param->offset_[0];
  if (IsSameTranspose(trans, kNHWC2NCHWTrans)) {
    if (input_tensor.Height() != shape_tensor.Height() || input_tensor.Width() != shape_tensor.Width()) {
      return RET_NO_CHANGE;
    }
    crop_param->axis_ = 1;
    crop_param->offset_size_ = 3;
    crop_param->offset_[0] = offset0;
    crop_param->offset_[1] = 0;
    crop_param->offset_[2] = 0;
    return RET_OK;
  }
  MS_LOG(ERROR) << "Unknown transpose info: from " << trans.src_format_ << " to " << trans.dst_format_;
  return RET_ERROR;
}
}  // namespace
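
// Rewrite the axis_ and offset_ of a Crop kernel for the given layout change. Requires static 4-D input shapes;
// returns RET_NO_CHANGE when the crop cannot be expressed in the new layout.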
int HandleCropKernel(const kernel::KernelExec *kernel, const TransInfoPair &trans) {
  auto crop_param = reinterpret_cast<CropParameter *>(kernel->op_parameter());
  CHECK_NULL_RETURN(crop_param);
  auto inputs = kernel->in_tensors();
  for (const auto &input : inputs) {
    auto shape = input->shape();
    if (shape.size() != DIMENSION_4D) {
      return RET_NO_CHANGE;
    }
    if (std::any_of(shape.begin(), shape.end(), [](const int &dim) { return dim < 0; })) {
      return RET_NO_CHANGE;
    }
  }
  if (crop_param->axis_ == 0) {
    return Handle0AxisCrop(trans, crop_param);
  }
  if (crop_param->axis_ == 1) {
    return Handle1AxisCrop(trans, crop_param);
  }
  if (crop_param->axis_ == 2) {
    return Handle2AxisCrop(trans, crop_param, *inputs[0], *inputs[1]);
  }
  if (crop_param->axis_ == 3) {
    return Handle3AxisCrop(trans, crop_param, *inputs[0], *inputs[1]);
  }
  MS_LOG(ERROR) << "Crop axis out of range for 4-D input, axis: " << crop_param->axis_;
  return RET_ERROR;
}

// The mapped function determines how the kernel's axis attribute (if any) is adjusted for a layout change.
// A kernel with an axis-adjusting function can be processed only for NHWC2NCHW or NCHW2NHWC.
static const std::unordered_map<schema::PrimitiveType, TransAxisFunc> kTransAxisFuncs = {
  {schema::PrimitiveType_Abs, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Activation, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_AddFusion, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_AddN, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_ArgMaxFusion, HandleArgMinMaxKernel},
  {schema::PrimitiveType_ArgMinFusion, HandleArgMinMaxKernel},
  {schema::PrimitiveType_Cast, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Ceil, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Clip, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Concat, HandleConcatKernel},
  {schema::PrimitiveType_Cos, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Crop, HandleCropKernel},
  {schema::PrimitiveType_DivFusion, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Elu, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Eltwise, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Equal, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_ExpFusion, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Floor, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_FloorDiv, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_FloorMod, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Greater, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_GreaterEqual, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Less, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_LessEqual, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Log, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_LogicalAnd, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_LogicalNot, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_LogicalOr, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Maximum, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Minimum, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Mod, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_MulFusion, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Neg, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_NotEqual, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_PowFusion, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_QuantDTypeCast, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_RealDiv, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Round, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Rsqrt, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Sin, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_SliceFusion, kNotImplementedTransAxisFunc},
  {schema::PrimitiveType_Softmax, HandleSoftMaxKernel},
  {schema::PrimitiveType_Split, HandleSplitKernel},
  {schema::PrimitiveType_Sqrt, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Squeeze, HandleSqueezeKernel},
  {schema::PrimitiveType_Square, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_SquaredDifference, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Stack, HandleStackKernel},
  {schema::PrimitiveType_StridedSlice, kNotImplementedTransAxisFunc},
  {schema::PrimitiveType_SubFusion, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Unsqueeze, HandleUnSqueezeKernel},
  {schema::PrimitiveType_Unstack, HandleUnStackKernel},
  {schema::PrimitiveType_LogSoftmax, HandleSoftMaxKernel},
  {schema::PrimitiveType_Erf, kNoNeedTransAxisFunc},
};
}  // namespace
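
// Pre-check for fusing the transposes around `kernel` through it: the surrounding transposes must be mutually
// opposite, the fusion must remove more transpose ops than it inserts, the kernel's axis attribute (if any)
// must be adjustable for the new layout, and the input shapes must be valid in that layout.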
bool TransposeStrategy::CrossKernelFusionPreCheck(const kernel::KernelExec *kernel, TransInfoPair *pre_trans,
                                                  TransInfoPair *post_trans) {
  if (kTransAxisFuncs.find(kernel::SchemaType(kernel->type())) == kTransAxisFuncs.end()) {
    return false;
  }
  auto input_count = GetTransCount(kernel->in_kernels(), pre_trans);
  auto output_count = GetTransCount(kernel->out_kernels(), post_trans);
  if (IsSameTranspose(*pre_trans, *post_trans)) {
    return false;
  }
  if (!IsOppositiveTranspose(*pre_trans, *post_trans)) {
    return false;
  }
  auto in_and_out_size = kernel->in_tensors().size() + kernel->out_kernels().size();
  if ((input_count + output_count) <= in_and_out_size / C2NUM) {
    MS_LOG(DEBUG) << "The fusion can't decrease the number of transpose ops.";
    return false;
  }
  if (IsNoneTranspose(*pre_trans)) {
    pre_trans->src_format_ = post_trans->dst_format_;
    pre_trans->dst_format_ = post_trans->src_format_;
  }
  if (IsNoneTranspose(*post_trans)) {
    post_trans->src_format_ = pre_trans->dst_format_;
    post_trans->dst_format_ = pre_trans->src_format_;
  }

  if (((!IsSameTranspose(*post_trans, kNCHW2NHWCTrans)) && (!IsSameTranspose(*post_trans, kNHWC2NCHWTrans))) &&
      kTransAxisFuncs.at(kernel::SchemaType(kernel->type()))) {
    return false;
  }
  if (!CheckInTensorsShape(kernel, (Format)(post_trans->dst_format_))) {
    return false;
  }
  return true;
}
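
// Apply the registered TransAxisFunc for this kernel type, rewriting its axis attribute for `trans`.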
int TransposeStrategy::TryTransKernelAxis(kernel::KernelExec *kernel, const TransInfoPair &trans) {
  auto trans_axis_func = kTransAxisFuncs.find(kernel::SchemaType(kernel->type()));
  if (trans_axis_func == kTransAxisFuncs.end() || trans_axis_func->second == nullptr) {
    MS_LOG(ERROR) << "Can't find the axis change function for " << kernel->name();
    return RET_ERROR;
  }
  return trans_axis_func->second(kernel, trans);
}
}  // namespace mindspore::lite::pass