• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_LITE_SRC_RUNTIME_PASS_FORMAT_PASS_UTILS_H_
17 #define MINDSPORE_LITE_SRC_RUNTIME_PASS_FORMAT_PASS_UTILS_H_
18 
19 #include <string>
20 #include <vector>
21 #include <functional>
22 #include "src/executor/kernel_exec.h"
23 #include "src/executor/sub_graph_kernel.h"
24 
25 namespace mindspore::lite::pass {
26 static const std::vector<int> nh2nc_perm = {0, 3, 1, 2};
27 static const std::vector<int> nc2nh_perm = {0, 2, 3, 1};
28 struct TransInfoPair {
29   mindspore::Format src_format_;
30   mindspore::Format dst_format_;
TransInfoPairTransInfoPair31   TransInfoPair() : src_format_(DEFAULT_FORMAT), dst_format_(DEFAULT_FORMAT) {}
TransInfoPairTransInfoPair32   TransInfoPair(Format src, Format dst) : src_format_(src), dst_format_(dst) {}
33 };
34 
35 using CreateFormatTransposeFunc = std::function<kernel::KernelExec *(
36   InferTensor *input, InferTensor *output, const TransInfoPair &trans_info, const std::string &name,
37   const lite::InnerContext *ctx, const kernel::KernelKey &desc)>;
38 
IsNCHWFormat(Format format)39 inline bool IsNCHWFormat(Format format) { return format == NCHW || format == NC4HW4 || format == NC8HW8; }
40 
41 bool IsNoneTranspose(const TransInfoPair &trans);
42 
43 bool IsSameTranspose(const TransInfoPair &trans0, const TransInfoPair &trans1);
44 
45 bool IsOppositiveTranspose(const TransInfoPair &trans0, const TransInfoPair &trans1);
46 
47 template <typename ShapeDT>
TransShape(const std::vector<ShapeDT> & shape,const TransInfoPair & trans,bool * ret)48 std::vector<ShapeDT> TransShape(const std::vector<ShapeDT> &shape, const TransInfoPair &trans, bool *ret) {
49   *ret = true;
50   if (shape.size() != DIMENSION_4D) {
51     return shape;
52   }
53   if (trans.src_format_ == trans.dst_format_ || (IsNCHWFormat(trans.src_format_) && IsNCHWFormat(trans.dst_format_))) {
54     return shape;
55   }
56   if (IsNCHWFormat(trans.src_format_) && trans.dst_format_ == NHWC) {
57     return {shape[0], shape[2], shape[3], shape[1]};
58   } else if (trans.src_format_ == NHWC && IsNCHWFormat(trans.dst_format_)) {
59     return {shape[0], shape[3], shape[1], shape[2]};
60   } else {
61     MS_LOG(WARNING) << "Unsupported transpose perm, from " << FormatEnumToString(trans.src_format_) << " to "
62                     << FormatEnumToString(trans.dst_format_);
63     *ret = false;
64     return {};
65   }
66 }
67 
68 bool TransTensorShapeAndFormat(Tensor *tensor, Format dst_format);
69 bool SetShape(const Tensor *src_tensor, Tensor *dst_tensor);
70 bool SetShape4D(const Tensor *src_tensor, Tensor *dst_tensor);
71 
72 int InsertPreTranspose(kernel::SubGraphKernel *subgraph, kernel::KernelExec *kernel, std::vector<Tensor *> *all_tensors,
73                        const TransInfoPair &trans_info, const size_t &index, const CreateFormatTransposeFunc &func);
74 
75 int InsertPostTranspose(kernel::SubGraphKernel *subgraph, kernel::KernelExec *kernel,
76                         std::vector<Tensor *> *all_tensors, const TransInfoPair &trans_info, const size_t &index,
77                         const CreateFormatTransposeFunc &func);
78 
79 int GetTransposeInfo(const kernel::KernelExec *kernel, TransInfoPair *trans_info);
80 }  // namespace mindspore::lite::pass
81 #endif  // MINDSPORE_LITE_SRC_RUNTIME_PASS_FORMAT_PASS_UTILS_H_
82