/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_COMMON_TRANS_H
#define MINDSPORE_CCSRC_COMMON_TRANS_H

#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ir/dtype.h"
#include "backend/kernel_compiler/kernel.h"
#include "ir/dtype/type.h"
#include "utils/shape_utils.h"
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {
namespace trans {
constexpr int64_t kAlign16 = 16;

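// Axis indices for 4-D NCHW shapes; kNchwDims (= 4) doubles as the NCHW rank.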
enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims };

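// Axis indices for 5-D shapes. Two layouts share one enum: the host-side
// NCDHW order (with kNcdhw = 5 as its rank) and the device-side NDC1HWC0
// order, whose enumerators restart at 0 and alias the NCDHW ones.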
enum Axis5D : int {
  N_ncdhw = 0,
  C_ncdhw,
  D_ncdhw,
  H_ncdhw,
  W_ncdhw,
  kNcdhw,
  N_ndc1hwc0 = 0,
  D_ndc1hwc0,
  C1_ndc1hwc0,
  H_ndc1hwc0,
  W_ndc1hwc0,
  C0_ndc1hwc0
};

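// Bundles the arguments for a data type conversion (see TransDataType below).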
struct TypeIdArgs {
  const void *data;
  size_t host_shape_size;  // Product of all dimensions: [a, b, c, d] => a*b*c*d.
  TypeId host_data_type;
  TypeId device_data_type;
  size_t data_size;
};

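// Bundles the arguments for a format transformation between a host layout and
// a device layout (see TransFormat and TransFormatFromDeviceToHost below).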
struct FormatArgs {
  const void *data;
  const size_t device_size;
  std::string host_format;
  std::string device_format;
  std::vector<size_t> host_shape;
  std::vector<size_t> device_shape;
  TypeId src_data_type;
};

int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index);
std::vector<int64_t> GetAttrInputAndHiddenSize(const AnfNodePtr &node);
void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec);
void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec);
ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);
bool IsNeedPadding(const std::string &format, const size_t shape_size);
int64_t GetNodeGroups(const AnfNodePtr &node);
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
                                       const int64_t groups = 1,
                                       const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16});
std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const std::string &format,
                                        const int64_t groups = 1,
                                        const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16});
template <typename T>
std::vector<T> TransShapeToDevice(const std::vector<T> &shape, const std::string &format, const AnfNodePtr &node,
                                  const size_t index, bool is_output = true) {
  int64_t groups = 1;
  if (format == kOpFormat_FRAC_Z) {
    groups = GetAttrGroups(node, index);
  }
  std::vector<int64_t> input_hidden_size = {kAlign16, kAlign16};
  if (format == kOpFormat_FRACTAL_ZN_RNN || format == kOpFormat_ND_RNN_BIAS) {
    input_hidden_size = GetAttrInputAndHiddenSize(node);
  }
  if (node != nullptr) {
    MS_LOG(DEBUG) << "Start trans infer shape to device shape for node: " << node->DebugString()
                  << ", format: " << format;
  }

  return TransShapeToDevice(shape, format, groups, input_hidden_size);
}
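
// A minimal usage sketch for TransShapeToDevice (the shape and the expected
// result are illustrative, not taken from this header): on Ascend, NC1HWC0
// typically tiles C into C1 = ceil(C / 16) blocks of C0 = 16, so a 4-D NCHW
// host shape {1, 3, 224, 224} would become {1, 1, 224, 224, 16}.
//   std::vector<size_t> host_shape = {1, 3, 224, 224};
//   auto device_shape = trans::TransShapeToDevice(host_shape, kOpFormat_NC1HWC0);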
bool TransDataType(const TypeIdArgs &args, void *result);
bool TransFormat(const FormatArgs &args, void *result, int64_t groups = 1);
bool TransFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index);
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, int64_t groups = 1);
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index);
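
// A minimal sketch of a dtype conversion (buffer names and sizes are
// hypothetical; the field order follows the TypeIdArgs declaration above):
//   TypeIdArgs args{src_buf, 1 * 3 * 224 * 224, kNumberTypeFloat32,
//                   kNumberTypeFloat16, src_size};
//   if (!trans::TransDataType(args, dst_buf)) { /* handle the failure */ }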

// host to device
bool NchwTo4D(const FormatArgs &args, void *result);
bool NchwToFracZ(const FormatArgs &args, void *result);
bool NchwToFracNz(const FormatArgs &args, void *result);
bool NchwToNc1hwc0(const FormatArgs &args, void *result);
bool NcdhwToFracZ3D(const FormatArgs &args, void *result);
bool NchwToFracZc04(const FormatArgs &args, void *result);
bool NchwToNc1hwc04(const FormatArgs &args, void *result);
bool NchwToC1hwncoc0(const FormatArgs &args, void *result);
bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result);
bool NchwToFracZWithGroups(const FormatArgs &args, void *result, int64_t groups);

// device to host
bool ToNchw(const FormatArgs &args, void *result);
bool FracZToNchw(const FormatArgs &args, void *result);
bool FracNzToNchw(const FormatArgs &args, void *result);
bool Nc1hwc0ToNchw(const FormatArgs &args, void *result);
bool Nc1hwc04ToNchw(const FormatArgs &args, void *result);
bool FracZ3DToNcdhw(const FormatArgs &args, void *result);
bool C1hwncoc0ToNchw(const FormatArgs &args, void *result);
bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result);
bool FracZToNchwWithGroups(const FormatArgs &args, void *result, int64_t groups);
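// Dispatch table from a device format string to its host-to-device transfer
// routine. Transfers that take extra arguments (e.g. the grouped FracZ variant
// NchwToFracZWithGroups) do not fit the FormatTransfer signature and are
// presumably dispatched separately by TransFormat.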
using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
const std::map<std::string, FormatTransfer> kTransFormatMapOfHostToDevice{
  {kOpFormat_FRAC_Z, NchwToFracZ},           {kOpFormat_FRAC_NZ, NchwToFracNz},
  {kOpFormat_NC1HWC0, NchwToNc1hwc0},        {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
  {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04},
  {kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0},     {kOpFormat_FRACTAL_Z_3D, NcdhwToFracZ3D}};

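// Pads a shape of rank < 5 up to NCDHW rank 5, filling dimensions in from the
// C axis and leaving the rest as 1, e.g. {3, 4} -> {1, 3, 4, 1, 1}.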
template <typename T>
std::vector<T> PaddingShapeTo5dDefault(const std::vector<T> &shape) {
  if (shape.size() >= kNcdhw) {
    return shape;
  }
  std::vector<T> shape_5d(kNcdhw, 1);
  switch (shape.size()) {
    case N_ncdhw:
      return shape_5d;
    case C_ncdhw:
      shape_5d[C_ncdhw] = shape[N_ncdhw];
      break;
    case D_ncdhw:
      shape_5d[C_ncdhw] = shape[N_ncdhw];
      shape_5d[D_ncdhw] = shape[C_ncdhw];
      break;
    case H_ncdhw:
      shape_5d[C_ncdhw] = shape[N_ncdhw];
      shape_5d[D_ncdhw] = shape[C_ncdhw];
      shape_5d[H_ncdhw] = shape[D_ncdhw];
      break;
    case W_ncdhw:
      shape_5d[C_ncdhw] = shape[N_ncdhw];
      shape_5d[D_ncdhw] = shape[C_ncdhw];
      shape_5d[H_ncdhw] = shape[D_ncdhw];
      shape_5d[W_ncdhw] = shape[H_ncdhw];
      break;
    default:
      MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
  }
  return shape_5d;
}

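// 4-D counterpart: pads a shape of rank <= 4 up to NCHW rank 4, e.g.
// {3} -> {1, 3, 1, 1} and {3, 4, 5} -> {1, 3, 4, 5}; rank-4 shapes are
// copied through unchanged.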
template <typename T>
std::vector<T> PaddingShapeTo4dDefault(const std::vector<T> &shape) {
  std::vector<T> shape_4d(kNchwDims, 1);
  switch (shape.size()) {
    case kN:
      return shape_4d;
    case kC:
      shape_4d[kC] = shape[kN];
      break;
    case kH:
      shape_4d[kC] = shape[kN];
      shape_4d[kH] = shape[kC];
      break;
    case kW:
      shape_4d[kC] = shape[kN];
      shape_4d[kH] = shape[kC];
      shape_4d[kW] = shape[kH];
      break;
    case kNchwDims:
      std::copy(shape.begin(), shape.end(), shape_4d.begin());
      break;
    default:
      MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
  }
  return shape_4d;
}

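// Pads to rank 5 using an explicit axis mapping parsed from padding_str;
// falls back to PaddingShapeTo5dDefault when the string is empty or its
// parsed length does not match the input rank.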
template <typename T>
std::vector<T> PaddingShapeTo5d(const std::vector<T> &shape, const std::string &padding_str = {""}) {
  std::vector<Axis5D> padding_axis;
  StringToAxisVector5D(padding_str, &padding_axis);
  if (padding_axis.empty() || shape.size() != padding_axis.size()) {
    return PaddingShapeTo5dDefault(shape);
  }
  std::vector<T> shape_5d(kNcdhw, 1);
  for (size_t index = 0; index < padding_axis.size(); index++) {
    shape_5d[padding_axis[index]] = shape[index];
  }
  return shape_5d;
}

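// 4-D counterpart of PaddingShapeTo5d, using Axis mappings parsed by
// StringToAxisVector4D.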
template <typename T>
std::vector<T> PaddingShapeTo4d(const std::vector<T> &shape, const std::string &padding_str = {""}) {
  std::vector<Axis> padding_axis;
  StringToAxisVector4D(padding_str, &padding_axis);
  if (padding_axis.empty() || shape.size() != padding_axis.size()) {
    return PaddingShapeTo4dDefault(shape);
  }
  std::vector<T> shape_4d(kNchwDims, 1);
  for (size_t index = 0; index < padding_axis.size(); index++) {
    shape_4d[padding_axis[index]] = shape[index];
  }
  return shape_4d;
}

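// Entry point for shape padding: formats in k3DFormatSet are padded to 5-D
// NCDHW, everything else to 4-D NCHW. For example, with an empty pad_index a
// 2-D shape {3, 4} under a non-3-D format becomes {1, 3, 4, 1}.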
template <typename T>
std::vector<T> PaddingShape(const std::vector<T> &shape, const std::string &format,
                            const std::string &pad_index = {""}) {
  std::vector<T> host_shape;
  if (k3DFormatSet.find(format) != k3DFormatSet.end()) {
    if (shape.size() >= kNcdhw) {
      return shape;
    }
    host_shape = trans::PaddingShapeTo5d(shape, pad_index);
  } else {
    host_shape = trans::PaddingShapeTo4d(shape, pad_index);
  }
  return host_shape;
}
}  // namespace trans
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_COMMON_TRANS_H