1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef MINDSPORE_LITE_SRC_RUNTIME_PASS_TRANSPOSE_STRATEGY_H_
17 #define MINDSPORE_LITE_SRC_RUNTIME_PASS_TRANSPOSE_STRATEGY_H_
18
19 #include <map>
20 #include <set>
21 #include <vector>
22 #include <functional>
23 #include <unordered_map>
24 #include "src/executor/kernel_exec.h"
25 #include "src/litert/pass/format_pass/pass_utils.h"
26
27 namespace mindspore::lite::pass {
28 static TransInfoPair kNHWC2NCHWTrans = {Format::NHWC, Format::NCHW};
29 static TransInfoPair kNCHW2NHWCTrans = {Format::NCHW, Format::NHWC};
30
31 template <typename T>
TransFormAxis(T axis,const TransInfoPair & trans)32 T TransFormAxis(T axis, const TransInfoPair &trans) {
33 if (IsSameTranspose(trans, kNHWC2NCHWTrans)) {
34 switch (axis) {
35 case kNHWC_N:
36 return kNCHW_N;
37 case kNHWC_H:
38 return kNCHW_H;
39 case kNHWC_W:
40 return kNCHW_W;
41 case kNHWC_C:
42 return kNCHW_C;
43 default:
44 return axis;
45 }
46 }
47 if (IsSameTranspose(trans, kNCHW2NHWCTrans)) {
48 switch (axis) {
49 case kNCHW_N:
50 return kNHWC_N;
51 case kNCHW_H:
52 return kNHWC_H;
53 case kNCHW_W:
54 return kNHWC_W;
55 case kNCHW_C:
56 return kNHWC_C;
57 default:
58 return axis;
59 }
60 }
61 return axis;
62 }
63
64 class TransposeStrategy {
65 public:
66 TransposeStrategy() = default;
67 ~TransposeStrategy() = default;
68
69 size_t GetTransCount(const std::vector<kernel::KernelExec *> &kernels, TransInfoPair *trans_info);
70 bool CrossKernelFusionPreCheck(const kernel::KernelExec *kernel, TransInfoPair *pre_trans, TransInfoPair *post_trans);
71 static int TryTransKernelAxis(kernel::KernelExec *kernel, const TransInfoPair &trans);
72 };
73 } // namespace mindspore::lite::pass
74 #endif // MINDSPORE_LITE_SRC_RUNTIME_PASS_TRANSPOSE_STRATEGY_H_
75