1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef MINDSPORE_LITE_SRC_RUNTIME_PASS_FORMAT_PASS_UTILS_H_
17 #define MINDSPORE_LITE_SRC_RUNTIME_PASS_FORMAT_PASS_UTILS_H_
18
19 #include <string>
20 #include <vector>
21 #include <functional>
22 #include "src/executor/kernel_exec.h"
23 #include "src/executor/sub_graph_kernel.h"
24
25 namespace mindspore::lite::pass {
// Axis permutations between NHWC and NCHW for 4-D tensors: output axis i is taken from input axis perm[i].
// nh2nc_perm maps an NHWC-ordered shape to NCHW order, and nc2nh_perm is its inverse (see TransShape below,
// which applies the same index orders).
// `inline` (C++17) gives one shared definition across all translation units instead of the per-TU copy,
// and the per-TU dynamic initialisation, that a namespace-scope `static` in a header produces.
inline const std::vector<int> nh2nc_perm = {0, 3, 1, 2};
inline const std::vector<int> nc2nh_perm = {0, 2, 3, 1};
28 struct TransInfoPair {
29 mindspore::Format src_format_;
30 mindspore::Format dst_format_;
TransInfoPairTransInfoPair31 TransInfoPair() : src_format_(DEFAULT_FORMAT), dst_format_(DEFAULT_FORMAT) {}
TransInfoPairTransInfoPair32 TransInfoPair(Format src, Format dst) : src_format_(src), dst_format_(dst) {}
33 };
34
35 using CreateFormatTransposeFunc = std::function<kernel::KernelExec *(
36 InferTensor *input, InferTensor *output, const TransInfoPair &trans_info, const std::string &name,
37 const lite::InnerContext *ctx, const kernel::KernelKey &desc)>;
38
IsNCHWFormat(Format format)39 inline bool IsNCHWFormat(Format format) { return format == NCHW || format == NC4HW4 || format == NC8HW8; }
40
// The three predicates below are defined out of line; their definitions are not visible in this header.

// NOTE(review): presumably true when `trans` does not change the layout (e.g. src_format_ == dst_format_) —
// confirm against the definition.
bool IsNoneTranspose(const TransInfoPair &trans);

// NOTE(review): presumably true when both pairs describe the same src -> dst transpose — confirm.
bool IsSameTranspose(const TransInfoPair &trans0, const TransInfoPair &trans1);

// NOTE(review): presumably true when `trans1` undoes `trans0` (e.g. NHWC -> NCHW followed by NCHW -> NHWC) —
// confirm. The name's spelling is part of the public interface and is kept as-is.
bool IsOppositiveTranspose(const TransInfoPair &trans0, const TransInfoPair &trans1);
46
47 template <typename ShapeDT>
TransShape(const std::vector<ShapeDT> & shape,const TransInfoPair & trans,bool * ret)48 std::vector<ShapeDT> TransShape(const std::vector<ShapeDT> &shape, const TransInfoPair &trans, bool *ret) {
49 *ret = true;
50 if (shape.size() != DIMENSION_4D) {
51 return shape;
52 }
53 if (trans.src_format_ == trans.dst_format_ || (IsNCHWFormat(trans.src_format_) && IsNCHWFormat(trans.dst_format_))) {
54 return shape;
55 }
56 if (IsNCHWFormat(trans.src_format_) && trans.dst_format_ == NHWC) {
57 return {shape[0], shape[2], shape[3], shape[1]};
58 } else if (trans.src_format_ == NHWC && IsNCHWFormat(trans.dst_format_)) {
59 return {shape[0], shape[3], shape[1], shape[2]};
60 } else {
61 MS_LOG(WARNING) << "Unsupported transpose perm, from " << FormatEnumToString(trans.src_format_) << " to "
62 << FormatEnumToString(trans.dst_format_);
63 *ret = false;
64 return {};
65 }
66 }
67
// Defined out of line; the definitions are not visible in this header.

// NOTE(review): presumably permutes `tensor`'s shape (e.g. via TransShape) and sets its format to `dst_format`,
// returning false on an unsupported conversion — confirm against the definition.
bool TransTensorShapeAndFormat(Tensor *tensor, Format dst_format);
// NOTE(review): presumably derives `dst_tensor`'s shape from `src_tensor`; returns a success flag — confirm.
bool SetShape(const Tensor *src_tensor, Tensor *dst_tensor);
// NOTE(review): presumably the 4-D-specific variant of SetShape — confirm.
bool SetShape4D(const Tensor *src_tensor, Tensor *dst_tensor);
71
// Both functions are defined out of line and receive a CreateFormatTransposeFunc `func` used to build the
// transpose kernel. They return an int status code.
// NOTE(review): presumably InsertPreTranspose places a transpose described by `trans_info` in front of input
// `index` of `kernel` (and InsertPostTranspose after output `index`) inside `subgraph`, appending any newly
// created tensors to `all_tensors`, and returns RET_OK on success — confirm against the definitions.
// NOTE(review): `const size_t &index` would be more idiomatic by value, but that change must be made together
// with the out-of-line definitions to avoid declaring a separate overload.
int InsertPreTranspose(kernel::SubGraphKernel *subgraph, kernel::KernelExec *kernel, std::vector<Tensor *> *all_tensors,
                       const TransInfoPair &trans_info, const size_t &index, const CreateFormatTransposeFunc &func);

int InsertPostTranspose(kernel::SubGraphKernel *subgraph, kernel::KernelExec *kernel,
                        std::vector<Tensor *> *all_tensors, const TransInfoPair &trans_info, const size_t &index,
                        const CreateFormatTransposeFunc &func);
78
// Defined out of line; it returns an int status code and reports its result through `trans_info`.
// NOTE(review): presumably inspects a transpose `kernel` (its perm) to fill the source/destination formats,
// failing for perms it does not recognise — confirm against the definition.
int GetTransposeInfo(const kernel::KernelExec *kernel, TransInfoPair *trans_info);
80 } // namespace mindspore::lite::pass
81 #endif // MINDSPORE_LITE_SRC_RUNTIME_PASS_FORMAT_PASS_UTILS_H_
82