• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MS_DEVICE_SHAPE_TRANSFER_H_
17 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_MS_DEVICE_SHAPE_TRANSFER_H_
18 #include <algorithm>
19 #include <functional>
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <set>
24 #include <utility>
25 #include <vector>
26 #include <numeric>
27 #include <optional>
28 #include <unordered_map>
29 #include "kernel/oplib/oplib.h"
30 #include "ir/dtype.h"
31 #include "kernel/kernel.h"
32 #include "ir/dtype/type.h"
33 #include "utils/shape_utils.h"
34 #include "include/backend/anf_runtime_algorithm.h"
35 #include "include/common/utils/anfalgo.h"
36 #include "utils/ms_utils.h"
37 #include "abstract/utils.h"
38 #include "runtime/device/convert_tensor_utils.h"
39 #include "include/common/utils/convert_utils.h"
40 #include "utils/log_adapter.h"
41 #include "include/common/utils/utils.h"
42 #include "include/backend/visible.h"
43 #include "mindapi/base/shape_vector.h"
44 
45 namespace mindspore {
46 namespace trans {
// Default value (16) for each entry of `input_hidden_size` in the shape-transfer APIs of this header.
constexpr int64_t kAlign16 = 16;

// Axis indices of a 4-D NCHW shape; kNchwDims is the NCHW rank.
enum kAxis4D : int { kN = 0, kC = 1, kH = 2, kW = 3, kNchwDims = 4 };

// Axis indices for two 5-D layouts that share this enum, so their values deliberately overlap:
//   NCDHW    : N=0, C=1, D=2, H=3, W=4 (kNcdhw = 5 is the NCDHW rank)
//   NDC1HWC0 : N=0, D=1, C1=2, H=3, W=4, C0=5
enum Axis5D : int {
  N_ncdhw = 0,
  C_ncdhw = 1,
  D_ncdhw = 2,
  H_ncdhw = 3,
  W_ncdhw = 4,
  kNcdhw = 5,
  N_ndc1hwc0 = 0,
  D_ndc1hwc0 = 1,
  C1_ndc1hwc0 = 2,
  H_ndc1hwc0 = 3,
  W_ndc1hwc0 = 4,
  C0_ndc1hwc0 = 5
};

// A shape as a list of int64 dimensions.
using ShapeVector = std::vector<int64_t>;
// One (first, second) pair per dimension; presumably (lower, upper) bounds of a dynamic dim — TODO confirm.
using RangePair = std::vector<std::pair<int64_t, int64_t>>;
/**
 * Args when trans node's data type.
 *
 * Aggregate: fields are brace-initialized in declaration order, so do not reorder them.
 * */
struct TypeIdArgs {
  const void *data;        // Source data; this struct does not manage its lifetime.
  int64_t src_shape_size;  // Multiply each dimension elements. [a, b, c, d] => a*b*c*d
  TypeId src_data_type;    // Element type of `data`.
  TypeId dst_data_type;    // Element type to convert to.
  size_t data_size;        // NOTE(review): presumably the byte size of `data`; confirm against TransDataType.
};

/**
 * Args when trans node's data at host.
 *
 * Aggregate: fields are brace-initialized in declaration order, so do not reorder them.
 * */
struct FormatArgs {
  const void *data;          // Source data; this struct does not manage its lifetime.
  const size_t device_size;  // const member: a FormatArgs can be initialized but not assigned to.
  std::string host_format;
  std::string device_format;
  ShapeVector host_shape;
  ShapeVector device_shape;
  TypeId src_data_type;
};
88 
89 /**
90  * Trans data type at host from src type to dst type
91  * */
92 class DataTypeTransfer {
93  public:
94   DataTypeTransfer() = default;
95   ~DataTypeTransfer() = default;
96   bool TransDataType(const TypeIdArgs &args, void *result) const;
97 
98  private:
99   enum class DataTypeTransMode {
100     FROM_BOOL_TO_UINT8,
101     FROM_BOOL_TO_INT32,
102     FROM_BOOL_TO_FLOAT16,
103     FROM_BOOL_TO_FLOAT,
104     FROM_INT8_TO_INT32,
105     FROM_INT8_TO_FLOAT,
106     FROM_INT8_TO_FLOAT16,
107     FROM_UINT8_TO_INT32,
108     FROM_UINT8_TO_FLOAT16,
109     FROM_UINT8_TO_FLOAT,
110     FROM_UINT16_TO_INT32,
111     FROM_INT16_TO_INT32,
112     FROM_INT16_TO_INT64,
113     FROM_INT32_TO_BOOL,
114     FROM_INT32_TO_INT8,
115     FROM_INT32_TO_UINT8,
116     FROM_INT32_TO_INT16,
117     FROM_INT32_TO_UINT16,
118     FROM_INT32_TO_INT64,
119     FROM_INT32_TO_FLOAT16,
120     FROM_INT32_TO_FLOAT,
121     FROM_INT64_TO_INT16,
122     FROM_INT64_TO_INT32,
123     FROM_FLOAT16_TO_UINT8,
124     FROM_FLOAT16_TO_INT32,
125     FROM_FLOAT16_TO_FLOAT,
126     FROM_FLOAT_TO_INT32,
127     FROM_FLOAT_TO_FLOAT16,
128     FROM_FLOAT_TO_BFLOAT16,
129     FROM_FLOAT32_TO_FLOAT64,
130     FROM_FLOAT64_TO_FLOAT32
131   };
132   const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map = {
133     {std::pair<TypeId, TypeId>(kNumberTypeFloat64, kNumberTypeFloat32), DataTypeTransMode::FROM_FLOAT64_TO_FLOAT32},
134     {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat64), DataTypeTransMode::FROM_FLOAT32_TO_FLOAT64},
135     {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat16), DataTypeTransMode::FROM_FLOAT_TO_FLOAT16},
136     {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeBFloat16), DataTypeTransMode::FROM_FLOAT_TO_BFLOAT16},
137     {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeInt32), DataTypeTransMode::FROM_FLOAT_TO_INT32},
138     {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeFloat32), DataTypeTransMode::FROM_FLOAT16_TO_FLOAT},
139     {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeInt32), DataTypeTransMode::FROM_FLOAT16_TO_INT32},
140     {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeUInt8), DataTypeTransMode::FROM_FLOAT16_TO_UINT8},
141     {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat32), DataTypeTransMode::FROM_INT32_TO_FLOAT},
142     {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat16), DataTypeTransMode::FROM_INT32_TO_FLOAT16},
143     {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeUInt8), DataTypeTransMode::FROM_INT32_TO_UINT8},
144     {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt8), DataTypeTransMode::FROM_INT32_TO_INT8},
145     {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt16), DataTypeTransMode::FROM_INT32_TO_INT16},
146     {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeUInt16), DataTypeTransMode::FROM_INT32_TO_UINT16},
147     {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt64), DataTypeTransMode::FROM_INT32_TO_INT64},
148     {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeBool), DataTypeTransMode::FROM_INT32_TO_BOOL},
149     {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat32), DataTypeTransMode::FROM_UINT8_TO_FLOAT},
150     {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeInt32), DataTypeTransMode::FROM_UINT8_TO_INT32},
151     {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat16), DataTypeTransMode::FROM_UINT8_TO_FLOAT16},
152     {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat32), DataTypeTransMode::FROM_INT8_TO_FLOAT},
153     {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat16), DataTypeTransMode::FROM_INT8_TO_FLOAT16},
154     {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeInt32), DataTypeTransMode::FROM_INT8_TO_INT32},
155     {std::pair<TypeId, TypeId>(kNumberTypeInt64, kNumberTypeInt16), DataTypeTransMode::FROM_INT64_TO_INT16},
156     {std::pair<TypeId, TypeId>(kNumberTypeInt64, kNumberTypeInt32), DataTypeTransMode::FROM_INT64_TO_INT32},
157     {std::pair<TypeId, TypeId>(kNumberTypeUInt16, kNumberTypeInt32), DataTypeTransMode::FROM_UINT16_TO_INT32},
158     {std::pair<TypeId, TypeId>(kNumberTypeInt16, kNumberTypeInt32), DataTypeTransMode::FROM_INT16_TO_INT32},
159     {std::pair<TypeId, TypeId>(kNumberTypeInt16, kNumberTypeInt64), DataTypeTransMode::FROM_INT16_TO_INT64},
160     {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeInt32), DataTypeTransMode::FROM_BOOL_TO_INT32},
161     {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat), DataTypeTransMode::FROM_BOOL_TO_FLOAT},
162     {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeUInt8), DataTypeTransMode::FROM_BOOL_TO_UINT8},
163     {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat16), DataTypeTransMode::FROM_BOOL_TO_FLOAT16}};
164 
165   bool CastKernel(const TypeIdArgs &args, void *dst, int64_t data_size, DataTypeTransMode mode) const;
166 };
167 
/**
 * Trans host shape to device shape according to node's format.
 *
 * The public entry points are the two GetDeviceShapeByFormat overloads. The private static *DeviceShape helpers
 * hold one shape transform per device format (bodies live in the .cpp).
 * */
class BACKEND_EXPORT DeviceShapeTransfer {
 public:
  DeviceShapeTransfer() = default;
  ~DeviceShapeTransfer() = default;
  // Node-aware overload. `index` / `is_output` presumably select which input or output of `node` the shape belongs
  // to. NOTE(review): how node information is used (e.g. GetFixedDeviceShape, GetAttrInputAndHiddenSize) is only
  // visible in the .cpp — confirm there.
  ShapeVector GetDeviceShapeByFormat(const ShapeVector &shape, const std::string &format, const AnfNodePtr &node,
                                     size_t index, const TypeId &type, bool is_output = true) const;

  // Node-free overload: `groups` (default 1) and `input_hidden_size` (default {kAlign16, kAlign16}) are supplied
  // directly instead of being derived from a node.
  ShapeVector GetDeviceShapeByFormat(const ShapeVector &shape, const std::string &format, const TypeId &type,
                                     int64_t groups = 1,
                                     const ShapeVector &input_hidden_size = {kAlign16, kAlign16}) const;

 private:
  // NOTE(review): presumably reads input/hidden size attributes from `node`; confirm in the .cpp.
  ShapeVector GetAttrInputAndHiddenSize(const AnfNodePtr &node) const;
  // Returns std::nullopt when there is no fixed device shape. NOTE(review): the source of a "fixed" shape is
  // defined in the .cpp.
  std::optional<ShapeVector> GetFixedDeviceShape(const ShapeVector &, const AnfNodePtr &node, size_t index,
                                                 bool is_output = true) const;
  // Format dispatch shared by both public overloads (presumably); same defaults as the node-free overload.
  ShapeVector TransCore(const ShapeVector &shape, const std::string &format, const TypeId &type, int64_t groups = 1,
                        const ShapeVector &input_hidden_size = {kAlign16, kAlign16}) const;

  // trans functions: one host-shape -> device-shape transform per device format.
  static ShapeVector NCHWDeviceShape(const ShapeVector &shape, const TypeId &);
  static ShapeVector NHWCDeviceShape(const ShapeVector &shape, const TypeId &);
  static ShapeVector HWCNDeviceShape(const ShapeVector &shape, const TypeId &);
  static ShapeVector NCDHWDeviceShape(const ShapeVector &shape, const TypeId &);
  static ShapeVector NC1HWC04DeviceShape(const ShapeVector &shape, const TypeId &);
  static ShapeVector FRAC_ZC04DeviceShape(const ShapeVector &shape, const TypeId &);
  static ShapeVector ChannelLastDeviceShape(const ShapeVector &shape, const TypeId &);
  static ShapeVector FRAC_ZN_LSTMDeviceShape(const ShapeVector &shape, const TypeId &);
  static ShapeVector FRAC_ZDeviceShape(const ShapeVector &shape, const TypeId &type);
  static ShapeVector FRAC_NZDeviceShape(const ShapeVector &shape, const TypeId &type);
  static ShapeVector NC1HWC0DeviceShape(const ShapeVector &shape, const TypeId &type);
  static ShapeVector NDC1HWC0DeviceShape(const ShapeVector &shape, const TypeId &type);
  static ShapeVector FRAC_Z3DDeviceShape(const ShapeVector &shape, const TypeId &type);
  static ShapeVector C1HWNCOC0DeviceShape(const ShapeVector &shape, const TypeId &type);
  static ShapeVector NDRNNBiasDeviceShape(const ShapeVector &shape, const TypeId &type, int64_t hidden_size = 16);
  static ShapeVector FRAC_ZDeviceShapeWithGroups(const ShapeVector &shape, const TypeId &type, int64_t groups = 1);
  static ShapeVector FRAC_ZN_RNNDeviceShape(const ShapeVector &shape, const TypeId &type,
                                            const ShapeVector &input_hidden_size = {kAlign16, kAlign16});
};
209 
/**
 * Trans data at host according to the node's format.
 *
 * Two tables keyed by device format hold the converters:
 *  - `format_trans_fp_map`: host -> device.
 *  - `format_trans_bp_map`: device -> host.
 * NOTE(review): the lookups themselves happen in the .cpp; confirm there.
 * */
class FormatTransfer {
 public:
  FormatTransfer() = default;
  ~FormatTransfer() = default;

  // NOTE(review): `is_forward` presumably selects host->device (true) vs device->host (false); confirm in the .cpp.
  bool TransDataByFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, size_t index, bool is_forward);
  bool TransDataForwardCore(const FormatArgs &args, void *result, int64_t groups = 1);
  // "Backword" is how the public name is spelled; renaming it would break the .cpp definition and callers.
  bool TransDataBackwordCore(const FormatArgs &args, void *result, int64_t groups = 1);

 private:
  using TransferCore = std::function<bool(const FormatArgs &, void *)>;
  // The two maps below are non-static members, so every FormatTransfer instance builds its own copies. Their
  // initializers may name static functions declared further down because default member initializers are a
  // complete-class context.
  // fp map: device format -> host-to-device converter (HWCN and NHWC share NCHW_TO_4D).
  const std::map<std::string, TransferCore> format_trans_fp_map = {{kOpFormat_HWCN, NCHW_TO_4D},
                                                                   {kOpFormat_NHWC, NCHW_TO_4D},
                                                                   {kOpFormat_FRAC_Z, NCHW_TO_FRAC_Z},
                                                                   {kOpFormat_FRAC_NZ, NCHW_TO_FRAC_NZ},
                                                                   {kOpFormat_NC1HWC0, NCHW_TO_NC1HWC0},
                                                                   {kOpFormat_NDC1HWC0, NCDHW_TO_NDC1HWC0},
                                                                   {kOpFormat_C1HWNCoC0, NCHW_TO_C1HWNCOC0},
                                                                   {kOpFormat_NC1HWC0_C04, NCHW_TO_NC1HWC04},
                                                                   {kOpFormat_FRACTAL_Z_3D, NCDHW_TO_FRAC_Z3D},
                                                                   {kOpFormat_FRACTAL_Z_C04, NCHW_TO_FRAC_ZC04}};
  // bp map: device format -> device-to-host converter (HWCN and NHWC share TO_NCHW).
  // Unlike the fp map it has no kOpFormat_FRACTAL_Z_C04 entry, and no FRAC_ZC04 -> NCHW function is declared
  // below. That format therefore has no device-to-host converter in this table.
  const std::map<std::string, TransferCore> format_trans_bp_map = {{kOpFormat_HWCN, TO_NCHW},
                                                                   {kOpFormat_NHWC, TO_NCHW},
                                                                   {kOpFormat_FRAC_Z, FRAC_Z_TO_NCHW},
                                                                   {kOpFormat_FRAC_NZ, FRAC_NZ_TO_NCHW},
                                                                   {kOpFormat_NC1HWC0, NC1HWC0_TO_NCHW},
                                                                   {kOpFormat_NDC1HWC0, NDC1HWC0_TO_NCDHW},
                                                                   {kOpFormat_C1HWNCoC0, C1HWNCOC0_TO_NCHW},
                                                                   {kOpFormat_NC1HWC0_C04, NC1HWC04_TO_NCHW},
                                                                   {kOpFormat_FRACTAL_Z_3D, FRAC_Z3D_TO_NCDHW}};

  // NOTE(review): presumably validates `args` and reports an element count through `size`; confirm in the .cpp.
  static bool CheckArgs(const FormatArgs &args, int64_t *size);
  static bool TransShapeToHW_NZ(const ShapeVector &host_shape, ShapeVector *hw_shape);
  // HOST TO DEVICE
  static bool NCHW_TO_4D(const FormatArgs &args, void *result);
  static bool NCHW_TO_FRAC_Z(const FormatArgs &args, void *result);
  static bool NCHW_TO_NC1HWC0(const FormatArgs &args, void *result);
  static bool NCHW_TO_FRAC_NZ(const FormatArgs &args, void *result);
  static bool NCHW_TO_NC1HWC04(const FormatArgs &args, void *result);
  static bool NCHW_TO_FRAC_ZC04(const FormatArgs &args, void *result);
  static bool NCHW_TO_C1HWNCOC0(const FormatArgs &args, void *result);
  static bool NCDHW_TO_NDC1HWC0(const FormatArgs &args, void *result);
  static bool NCDHW_TO_FRAC_Z3D(const FormatArgs &args, void *result);
  // Grouped FRAC_Z conversion; `to_device` chooses the direction, so it serves both ways.
  static bool NCHW_TO_FRAC_Z_WITH_GROUPS(const FormatArgs &args, void *result, bool to_device, int64_t groups);

  // DEVICE TO HOST
  static bool TO_NCHW(const FormatArgs &args, void *result);
  static bool FRAC_Z_TO_NCHW(const FormatArgs &args, void *result);
  static bool FRAC_NZ_TO_NCHW(const FormatArgs &args, void *result);
  static bool NC1HWC0_TO_NCHW(const FormatArgs &args, void *result);
  static bool NC1HWC04_TO_NCHW(const FormatArgs &args, void *result);
  static bool C1HWNCOC0_TO_NCHW(const FormatArgs &args, void *result);
  static bool FRAC_Z3D_TO_NCDHW(const FormatArgs &args, void *result);
  static bool NDC1HWC0_TO_NCDHW(const FormatArgs &args, void *result);
  static bool FRAC_Z_TO_NCHW_WITH_GROUPS(const FormatArgs &args, void *result, int64_t groups);

  // common check_func
  static int64_t Common4DCheck(const FormatArgs &args);
};
274 
/**
 * Range trans function.
 *
 * Converts a per-dimension range list (RangePair) for `format`. Each supported device format has one private
 * static helper. NOTE(review): the dispatch and the range math live in the .cpp.
 * */
class BACKEND_EXPORT ShapeRangeTransfer {
 public:
  ShapeRangeTransfer() = default;
  ~ShapeRangeTransfer() = default;
  // `padding_str` defaults to "". NOTE(review): presumably a reshape-type string like the one PaddingShapeTo4d/5d
  // accept; confirm in the .cpp.
  RangePair GetRealRange(const RangePair &ori_range, const std::string &format, const TypeId &type,
                         const std::string &padding_str = {""}) const;

 private:
  // Per-format range transforms.
  static RangePair NHWCRange(const RangePair &ori_range, const TypeId &);
  static RangePair HWCNRange(const RangePair &ori_range, const TypeId &);
  static RangePair NC1HWC04Range(const RangePair &ori_range, const TypeId &);
  static RangePair FRAC_ZC04Range(const RangePair &ori_range, const TypeId &);
  static RangePair FRAC_ZN_LSTMRange(const RangePair &ori_range, const TypeId &);
  static RangePair FRAC_ZRange(const RangePair &ori_range, const TypeId &type);
  static RangePair FRAC_NZRange(const RangePair &ori_range, const TypeId &type);
  static RangePair NC1HWC0Range(const RangePair &ori_range, const TypeId &type);
  static RangePair NDC1HWC0Range(const RangePair &ori_range, const TypeId &type);
  static RangePair C1HWNCOC0Range(const RangePair &ori_range, const TypeId &type);
  static RangePair FRAC_Z_3DRange(const RangePair &ori_range, const TypeId &type);
};
298 
299 /**
300  * If you want extend format, make sure it has a data trans function at host in class
301  * 'FormatTransfer.format_trans_fp_map'
302  * */
303 static const std::set<std::string> kFormatWithTransFunc = {
304   kOpFormat_HWCN,     kOpFormat_NHWC,      kOpFormat_FRAC_Z,      kOpFormat_FRAC_NZ,      kOpFormat_NC1HWC0,
305   kOpFormat_NDC1HWC0, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_3D, kOpFormat_FRACTAL_Z_C04};
306 
/**
 * Interface of datatype trans: converts args.data from args.src_data_type to args.dst_data_type into `result`.
 * */
BACKEND_EXPORT bool TransDataType(const TypeIdArgs &args, void *result);

/**
 * Interface of data format trans from host to device.
 * `index` selects an input/output of `node` (NOTE(review): which side is decided in the .cpp).
 * */
BACKEND_EXPORT bool TransFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, size_t index);

/**
 * Interface of data format trans from device to host, with an explicit `groups` count (default 1).
 * */
BACKEND_EXPORT bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, int64_t groups = 1);

/**
 * Interface of data format trans from device to host, for input/output `index` of `node`.
 * */
BACKEND_EXPORT bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, const AnfNodePtr &node,
                                                size_t index);

/**
 * 4D reshape type trans: parses reshape_type_str into axes, returned through *reshape_type_vec.
 * */
BACKEND_EXPORT void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec);

/**
 * 5D reshape type trans: parses reshape_type_str into Axis5D values, returned through *reshape_type_vec.
 * */
BACKEND_EXPORT void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec);

/**
 * Get shape after padding for input/output `index` of `node`.
 * */
BACKEND_EXPORT ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);

/**
 * Whether `shape` needs padding for `format`.
 * */
BACKEND_EXPORT bool IsNeedPadding(const std::string &format, const ShapeVector &shape);
347 
348 /**
349  * Padding shape to 5D by default mode
350  * */
351 template <typename T>
352 std::vector<T> PaddingShapeTo5dDefault(const std::vector<T> &shape, const AnfNodePtr &node = nullptr) {
353   if (shape.size() >= kDim5) {
354     return shape;
355   }
356   std::vector<T> shape_5d(kNcdhw, 1);
357   switch (shape.size()) {
358     case N_ncdhw:
359       return shape_5d;
360     case C_ncdhw:
361       shape_5d[C_ncdhw] = shape[N_ncdhw];
362       break;
363     case D_ncdhw:
364       shape_5d[C_ncdhw] = shape[N_ncdhw];
365       shape_5d[D_ncdhw] = shape[C_ncdhw];
366       break;
367     case H_ncdhw:
368       shape_5d[C_ncdhw] = shape[N_ncdhw];
369       shape_5d[D_ncdhw] = shape[C_ncdhw];
370       shape_5d[H_ncdhw] = shape[D_ncdhw];
371       break;
372     case W_ncdhw:
373       shape_5d[C_ncdhw] = shape[N_ncdhw];
374       shape_5d[D_ncdhw] = shape[C_ncdhw];
375       shape_5d[H_ncdhw] = shape[D_ncdhw];
376       shape_5d[W_ncdhw] = shape[H_ncdhw];
377       break;
378     default:
379       auto node_info = (node != nullptr) ? ". Node: " + node->fullname_with_scope() : " .";
380       MS_LOG(INTERNAL_EXCEPTION) << "Unexpected shape :" << shape << node_info;
381   }
382   return shape_5d;
383 }
384 
385 /**
386  * Padding shape to 4D by default mode
387  * */
388 template <typename T>
389 std::vector<T> PaddingShapeTo4dDefault(const std::vector<T> &shape, const AnfNodePtr &node = nullptr) {
390   std::vector<T> shape_4d(kNchwDims, 1);
391   switch (shape.size()) {
392     case kN:
393       return shape_4d;
394     case kC:
395       shape_4d[kC] = shape[kN];
396       break;
397     case kH:
398       shape_4d[kC] = shape[kN];
399       shape_4d[kH] = shape[kC];
400       break;
401     case kW:
402       shape_4d[kC] = shape[kN];
403       shape_4d[kH] = shape[kC];
404       shape_4d[kW] = shape[kH];
405       break;
406     case kNchwDims:
407       return shape;
408     default:
409       auto node_info = (node != nullptr) ? ". Node: " + node->fullname_with_scope() : " .";
410       MS_LOG(INTERNAL_EXCEPTION) << "Unexpected shape : " << shape << node_info;
411   }
412   return shape_4d;
413 }
414 
415 /**
416  * Padding shape to 5D according to reshape type
417  * */
418 template <typename T>
419 std::vector<T> PaddingShapeTo5d(const std::vector<T> &shape, const std::string &padding_str = {""}) {
420   std::vector<Axis5D> padding_axis;
421   StringToAxisVector5D(padding_str, &padding_axis);
422   if (padding_axis.empty() || shape.size() > padding_axis.size()) {
423     return PaddingShapeTo5dDefault(shape);
424   }
425   std::vector<T> shape_5d(kNcdhw, 1);
426   for (size_t index = 0; index < shape.size(); index++) {
427     shape_5d[padding_axis[index]] = shape[index];
428   }
429   return shape_5d;
430 }
431 
432 /**
433  * Padding shape to 4D according to reshape type
434  * */
435 template <typename T>
436 std::vector<T> PaddingShapeTo4d(const std::vector<T> &shape, const std::string &padding_str = {""}) {
437   std::vector<Axis> padding_axis;
438   StringToAxisVector4D(padding_str, &padding_axis);
439   if (padding_axis.empty() || shape.size() > padding_axis.size()) {
440     return PaddingShapeTo4dDefault(shape);
441   }
442   std::vector<T> shape_4d(kNchwDims, 1);
443   for (size_t index = 0; index < shape.size(); index++) {
444     shape_4d[padding_axis[index]] = shape[index];
445   }
446   return shape_4d;
447 }
448 
449 /**
450  * Interface of padding shape
451  * */
452 template <typename T>
453 std::vector<T> PaddingShape(const std::vector<T> &shape, const std::string &format, const std::string &pad_index = {""},
454                             const AnfNodePtr &node = nullptr) {
455   if (node != nullptr) {
456     MS_LOG(DEBUG) << "Start padding shape for node: [" << node->fullname_with_scope() << "], format: " << format
457                   << ", detail info: " << node->DebugString();
458   }
459 
460   if (IsOneOf3DFormat(format)) {
461     if (shape.size() >= kDim5) {
462       return shape;
463     }
464     if (shape.size() == 1 && shape[0] == abstract::Shape::kShapeRankAny) {
465       return {-1, -1, -1, -1, -1};
466     }
467     return PaddingShapeTo5d(shape, pad_index);
468   }
469 
470   if (shape.size() == 1 && shape[0] == abstract::Shape::kShapeRankAny) {
471     return {-1, -1, -1, -1};
472   }
473   return PaddingShapeTo4d(shape, pad_index);
474 }
475 
476 /**
477  * Interface of transform pad_index string to AxisVector
478  * */
479 template <typename T>
480 std::vector<int> StringToAxisVector(const std::vector<T> &shape, const std::string &format,
481                                     const std::string &pad_index = {""}, const AnfNodePtr &node = nullptr) {
482   if (node != nullptr) {
483     MS_LOG(DEBUG) << "Start transform  pad_index to axis_vecor for node: [" << node->fullname_with_scope()
484                   << "], format: " << format << ", detail info: " << node->DebugString();
485   }
486 
487   std::vector<int> padding_axis;
488   if (IsOneOf3DFormat(format)) {
489     if (shape.size() >= kDim5) {
490       return padding_axis;
491     }
492     std::vector<Axis5D> padding_axis_5d;
493     StringToAxisVector5D(pad_index, &padding_axis_5d);
494 
495     if (padding_axis_5d.empty() || shape.size() != padding_axis_5d.size()) {
496       for (int index = 0; index < static_cast<int>(shape.size()); ++index) {
497         padding_axis.push_back(index);
498       }
499     } else {
500       (void)std::transform(padding_axis_5d.begin(), padding_axis_5d.end(), std::back_inserter(padding_axis),
501                            [](Axis5D x) { return static_cast<int>(x); });
502     }
503   } else {
504     std::vector<Axis> padding_axis_4d;
505     StringToAxisVector4D(pad_index, &padding_axis_4d);
506 
507     if (padding_axis_4d.empty() || shape.size() != padding_axis_4d.size()) {
508       for (int index = 0; index < static_cast<int>(shape.size()); ++index) {
509         padding_axis.push_back(index);
510       }
511     } else {
512       (void)std::transform(padding_axis_4d.begin(), padding_axis_4d.end(), std::back_inserter(padding_axis),
513                            [](Axis x) { return static_cast<int>(x); });
514     }
515   }
516 
517   return padding_axis;
518 }
519 
520 /**
521  * Interface of device shape trance
522  * */
523 template <typename T>
524 std::vector<T> TransShapeToDevice(const std::vector<T> &shape, const std::string &format, const AnfNodePtr &node,
525                                   size_t index, TypeId type, bool is_output = true) {
526   ShapeVector shape_before;
527   (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_before),
528                        [](T num) { return static_cast<int64_t>(num); });
529   DeviceShapeTransfer deviceShapeTransfer;
530   auto res = deviceShapeTransfer.GetDeviceShapeByFormat(shape_before, format, node, index, type, is_output);
531   std::vector<T> out_shape;
532   (void)std::transform(res.begin(), res.end(), std::back_inserter(out_shape),
533                        [](int64_t num) { return static_cast<T>(num); });
534   return out_shape;
535 }
536 
537 template <typename T>
538 std::vector<T> TransShapeToDevice(const std::vector<T> &shape, const std::string &format, TypeId type,
539                                   int64_t groups = 1, const ShapeVector &input_hidden_size = {kAlign16, kAlign16}) {
540   ShapeVector shape_before;
541   (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_before),
542                        [](T num) { return static_cast<int64_t>(num); });
543   DeviceShapeTransfer deviceShapeTransfer;
544   auto res = deviceShapeTransfer.GetDeviceShapeByFormat(shape_before, format, type, groups, input_hidden_size);
545   std::vector<T> out_shape;
546   (void)std::transform(res.begin(), res.end(), std::back_inserter(out_shape),
547                        [](int64_t num) { return static_cast<T>(num); });
548   return out_shape;
549 }
550 
551 struct FormatInfo {
FormatInfoFormatInfo552   FormatInfo(std::string format, bool is_padded) : baseFormat(format), isPadded(is_padded) {}
553   std::string baseFormat = kOpFormat_ND;
554   bool isPadded = false;
555 };
556 
557 class BACKEND_EXPORT FormatHelper {
558  public:
559   static FormatHelper &GetInstance() noexcept;
560   const std::string GetBaseFormat(const std::string &format);
561   bool IsBaseFormatType(const std::string &format);
562   bool IsPadded(const std::string &format);
563   void Clear();
564 
565  private:
FormatHelper()566   FormatHelper() { InitInfo(); }
567   ~FormatHelper() = default;
568   void InitInfo();
569 
570   std::unordered_map<std::string, FormatInfo> info;
571 };
572 }  // namespace trans
573 }  // namespace mindspore
574 
575 #endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_MS_DEVICE_SHAPE_TRANSFER_H_
576