1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_DPICO_COMMON_DATA_TRANSPOSE_UTILS_H_
18 #define MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_DPICO_COMMON_DATA_TRANSPOSE_UTILS_H_
19
20 #include <vector>
21 #include "mindapi/base/format.h"
22 #include "mindapi/ir/tensor.h"
23 #include "mindapi/base/logging.h"
24 #include "include/errorcode.h"
25 #include "common/op_enum.h"
26 using mindspore::lite::RET_ERROR;
27 using mindspore::lite::RET_NO_CHANGE;
28 using mindspore::lite::RET_OK;
29 using mindspore::lite::STATUS;
30 namespace mindspore {
31 namespace dpico {
// Axis permutations between the two 4-D layouts. Reading each entry as "output axis i takes input
// axis perm[i]", kNH2NC reorders (N, H, W, C) into (N, C, H, W) and kNC2NH does the reverse.
// NOTE(review): how callers actually apply these vectors is not visible in this header - confirm.
inline const std::vector<int> kNH2NC{0, 3, 1, 2};
inline const std::vector<int> kNC2NH{0, 2, 3, 1};
// Kinds of format-transform node. The numeric values are the ones the compiler assigned to the
// original unnumbered list (0, 1, 2), spelled out here so reordering cannot change them silently.
// NOTE(review): the direction each enumerator stands for is inferred from its name only - confirm.
enum FormatTransNodeType {
  kNCHW2NHWC = 0,
  kNHWC2NCHW = 1,
  kNONE = 2,  // the value a default-constructed TransTypePair holds in both members
};
35 struct TransTypePair {
36 FormatTransNodeType pre_;
37 FormatTransNodeType post_;
TransTypePairTransTypePair38 TransTypePair() : pre_(kNONE), post_(kNONE) {}
39 };
40 template <typename T>
NHWC2NCHW(T * src_data,T * dst_data,std::vector<int32_t> shape)41 STATUS NHWC2NCHW(T *src_data, T *dst_data, std::vector<int32_t> shape) {
42 if (shape.size() != kDims4) {
43 MS_LOG(ERROR) << "The dim should be 4.";
44 return RET_ERROR;
45 }
46 int32_t batch = shape.at(0);
47 int32_t plane = shape.at(kAxis1) * shape.at(kAxis2);
48 int32_t channel = shape.at(kAxis3);
49 for (int32_t b = 0; b < batch; b++) {
50 for (int32_t p = 0; p < plane; p++) {
51 for (int32_t c = 0; c < channel; c++) {
52 int32_t src_idx = b * plane * channel + p * channel + c;
53 int32_t dst_idx = b * channel * plane + c * plane + p;
54 dst_data[dst_idx] = src_data[src_idx];
55 }
56 }
57 }
58 return RET_OK;
59 }
60
61 template <typename T>
NCHW2NHWC(T * src_data,T * dst_data,std::vector<int32_t> shape)62 STATUS NCHW2NHWC(T *src_data, T *dst_data, std::vector<int32_t> shape) {
63 if (shape.size() != kDims4) {
64 MS_LOG(ERROR) << "The dim should be 4.";
65 return RET_ERROR;
66 }
67 int32_t batch = shape.at(0);
68 int32_t channel = shape.at(1);
69 int32_t plane = shape.at(kAxis2) * shape.at(kAxis3);
70 for (int32_t b = 0; b < batch; b++) {
71 for (int32_t c = 0; c < channel; c++) {
72 for (int32_t p = 0; p < plane; p++) {
73 int32_t src_idx = b * channel * plane + c * plane + p;
74 int32_t dst_idx = b * plane * channel + p * channel + c;
75 dst_data[dst_idx] = src_data[src_idx];
76 }
77 }
78 }
79 return RET_OK;
80 }
81
82 STATUS TransFilterFormat(const mindspore::api::TensorPtr &tensor, mindspore::Format src_format,
83 mindspore::Format dst_format);
84
// Declaration only; the definition is not part of this header.
// NOTE(review): presumably transposes the `row` x `col` float matrix stored at `matrix` - confirm
// against the implementation whether this happens in place and which storage order is assumed.
void TransposeMatrix(float *matrix,
                     int row,
                     int col);
86 } // namespace dpico
87 } // namespace mindspore
88 #endif // MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_DPICO_COMMON_DATA_TRANSPOSE_UTILS_H_
89