• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_DPICO_COMMON_DATA_TRANSPOSE_UTILS_H_
18 #define MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_DPICO_COMMON_DATA_TRANSPOSE_UTILS_H_
19 
20 #include <vector>
21 #include "mindapi/base/format.h"
22 #include "mindapi/ir/tensor.h"
23 #include "mindapi/base/logging.h"
24 #include "include/errorcode.h"
25 #include "common/op_enum.h"
26 using mindspore::lite::RET_ERROR;
27 using mindspore::lite::RET_NO_CHANGE;
28 using mindspore::lite::RET_OK;
29 using mindspore::lite::STATUS;
30 namespace mindspore {
31 namespace dpico {
// Axis permutations over four dimensions.
// NOTE(review): judging by the names, kNH2NC reorders NHWC axes into NCHW order and
// kNC2NH does the reverse - confirm against the call sites.
inline const std::vector<int> kNH2NC{0, 3, 1, 2};
inline const std::vector<int> kNC2NH{0, 2, 3, 1};
// Tag naming a layout conversion; kNONE is the value TransTypePair defaults to.
enum FormatTransNodeType {
  kNCHW2NHWC,  // NCHW -> NHWC
  kNHWC2NCHW,  // NHWC -> NCHW
  kNONE        // no conversion
};
35 struct TransTypePair {
36   FormatTransNodeType pre_;
37   FormatTransNodeType post_;
TransTypePairTransTypePair38   TransTypePair() : pre_(kNONE), post_(kNONE) {}
39 };
40 template <typename T>
NHWC2NCHW(T * src_data,T * dst_data,std::vector<int32_t> shape)41 STATUS NHWC2NCHW(T *src_data, T *dst_data, std::vector<int32_t> shape) {
42   if (shape.size() != kDims4) {
43     MS_LOG(ERROR) << "The dim should be 4.";
44     return RET_ERROR;
45   }
46   int32_t batch = shape.at(0);
47   int32_t plane = shape.at(kAxis1) * shape.at(kAxis2);
48   int32_t channel = shape.at(kAxis3);
49   for (int32_t b = 0; b < batch; b++) {
50     for (int32_t p = 0; p < plane; p++) {
51       for (int32_t c = 0; c < channel; c++) {
52         int32_t src_idx = b * plane * channel + p * channel + c;
53         int32_t dst_idx = b * channel * plane + c * plane + p;
54         dst_data[dst_idx] = src_data[src_idx];
55       }
56     }
57   }
58   return RET_OK;
59 }
60 
61 template <typename T>
NCHW2NHWC(T * src_data,T * dst_data,std::vector<int32_t> shape)62 STATUS NCHW2NHWC(T *src_data, T *dst_data, std::vector<int32_t> shape) {
63   if (shape.size() != kDims4) {
64     MS_LOG(ERROR) << "The dim should be 4.";
65     return RET_ERROR;
66   }
67   int32_t batch = shape.at(0);
68   int32_t channel = shape.at(1);
69   int32_t plane = shape.at(kAxis2) * shape.at(kAxis3);
70   for (int32_t b = 0; b < batch; b++) {
71     for (int32_t c = 0; c < channel; c++) {
72       for (int32_t p = 0; p < plane; p++) {
73         int32_t src_idx = b * channel * plane + c * plane + p;
74         int32_t dst_idx = b * plane * channel + p * channel + c;
75         dst_data[dst_idx] = src_data[src_idx];
76       }
77     }
78   }
79   return RET_OK;
80 }
81 
82 STATUS TransFilterFormat(const mindspore::api::TensorPtr &tensor, mindspore::Format src_format,
83                          mindspore::Format dst_format);
84 
85 void TransposeMatrix(float *matrix, int row, int col);
86 }  // namespace dpico
87 }  // namespace mindspore
88 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_DPICO_COMMON_DATA_TRANSPOSE_UTILS_H_
89