• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_LITE_SRC_RUNTIME_PASS_TRANSPOSE_STRATEGY_H_
17 #define MINDSPORE_LITE_SRC_RUNTIME_PASS_TRANSPOSE_STRATEGY_H_
18 
19 #include <map>
20 #include <set>
21 #include <vector>
22 #include <functional>
23 #include <unordered_map>
24 #include "src/executor/kernel_exec.h"
25 #include "src/litert/pass/format_pass/pass_utils.h"
26 
27 namespace mindspore::lite::pass {
28 static TransInfoPair kNHWC2NCHWTrans = {Format::NHWC, Format::NCHW};
29 static TransInfoPair kNCHW2NHWCTrans = {Format::NCHW, Format::NHWC};
30 
31 template <typename T>
TransFormAxis(T axis,const TransInfoPair & trans)32 T TransFormAxis(T axis, const TransInfoPair &trans) {
33   if (IsSameTranspose(trans, kNHWC2NCHWTrans)) {
34     switch (axis) {
35       case kNHWC_N:
36         return kNCHW_N;
37       case kNHWC_H:
38         return kNCHW_H;
39       case kNHWC_W:
40         return kNCHW_W;
41       case kNHWC_C:
42         return kNCHW_C;
43       default:
44         return axis;
45     }
46   }
47   if (IsSameTranspose(trans, kNCHW2NHWCTrans)) {
48     switch (axis) {
49       case kNCHW_N:
50         return kNHWC_N;
51       case kNCHW_H:
52         return kNHWC_H;
53       case kNCHW_W:
54         return kNHWC_W;
55       case kNCHW_C:
56         return kNHWC_C;
57       default:
58         return axis;
59     }
60   }
61   return axis;
62 }
63 
64 class TransposeStrategy {
65  public:
66   TransposeStrategy() = default;
67   ~TransposeStrategy() = default;
68 
69   size_t GetTransCount(const std::vector<kernel::KernelExec *> &kernels, TransInfoPair *trans_info);
70   bool CrossKernelFusionPreCheck(const kernel::KernelExec *kernel, TransInfoPair *pre_trans, TransInfoPair *post_trans);
71   static int TryTransKernelAxis(kernel::KernelExec *kernel, const TransInfoPair &trans);
72 };
73 }  // namespace mindspore::lite::pass
74 #endif  // MINDSPORE_LITE_SRC_RUNTIME_PASS_TRANSPOSE_STRATEGY_H_
75