1 /** 2 * Copyright 2021-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MS_DEVICE_SHAPE_TRANSFER_H_ 17 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_MS_DEVICE_SHAPE_TRANSFER_H_ 18 #include <algorithm> 19 #include <functional> 20 #include <map> 21 #include <memory> 22 #include <string> 23 #include <set> 24 #include <utility> 25 #include <vector> 26 #include <numeric> 27 #include <optional> 28 #include <unordered_map> 29 #include "kernel/oplib/oplib.h" 30 #include "ir/dtype.h" 31 #include "kernel/kernel.h" 32 #include "ir/dtype/type.h" 33 #include "utils/shape_utils.h" 34 #include "include/backend/anf_runtime_algorithm.h" 35 #include "include/common/utils/anfalgo.h" 36 #include "utils/ms_utils.h" 37 #include "abstract/utils.h" 38 #include "runtime/device/convert_tensor_utils.h" 39 #include "include/common/utils/convert_utils.h" 40 #include "utils/log_adapter.h" 41 #include "include/common/utils/utils.h" 42 #include "include/backend/visible.h" 43 #include "mindapi/base/shape_vector.h" 44 45 namespace mindspore { 46 namespace trans { 47 constexpr int64_t kAlign16 = 16; 48 enum kAxis4D : int { kN = 0, kC, kH, kW, kNchwDims }; 49 enum Axis5D : int { 50 N_ncdhw = 0, 51 C_ncdhw, 52 D_ncdhw, 53 H_ncdhw, 54 W_ncdhw, 55 kNcdhw, 56 N_ndc1hwc0 = 0, 57 D_ndc1hwc0, 58 C1_ndc1hwc0, 59 H_ndc1hwc0, 60 W_ndc1hwc0, 61 C0_ndc1hwc0 62 }; 63 using ShapeVector = std::vector<int64_t>; 64 using RangePair = std::vector<std::pair<int64_t, int64_t>>; 65 /** 66 * Args when trans node's data type 67 * */ 68 struct TypeIdArgs { 69 const void *data; 70 int64_t src_shape_size; // Multiply each dimension elements. [a, b, c, d] => a*b*c*d 71 TypeId src_data_type; 72 TypeId dst_data_type; 73 size_t data_size; 74 }; 75 76 /** 77 * Args when trans node's data at host 78 * */ 79 struct FormatArgs { 80 const void *data; 81 const size_t device_size; 82 std::string host_format; 83 std::string device_format; 84 ShapeVector host_shape; 85 ShapeVector device_shape; 86 TypeId src_data_type; 87 }; 88 89 /** 90 * Trans data type at host from src type to dst type 91 * */ 92 class DataTypeTransfer { 93 public: 94 DataTypeTransfer() = default; 95 ~DataTypeTransfer() = default; 96 bool TransDataType(const TypeIdArgs &args, void *result) const; 97 98 private: 99 enum class DataTypeTransMode { 100 FROM_BOOL_TO_UINT8, 101 FROM_BOOL_TO_INT32, 102 FROM_BOOL_TO_FLOAT16, 103 FROM_BOOL_TO_FLOAT, 104 FROM_INT8_TO_INT32, 105 FROM_INT8_TO_FLOAT, 106 FROM_INT8_TO_FLOAT16, 107 FROM_UINT8_TO_INT32, 108 FROM_UINT8_TO_FLOAT16, 109 FROM_UINT8_TO_FLOAT, 110 FROM_UINT16_TO_INT32, 111 FROM_INT16_TO_INT32, 112 FROM_INT16_TO_INT64, 113 FROM_INT32_TO_BOOL, 114 FROM_INT32_TO_INT8, 115 FROM_INT32_TO_UINT8, 116 FROM_INT32_TO_INT16, 117 FROM_INT32_TO_UINT16, 118 FROM_INT32_TO_INT64, 119 FROM_INT32_TO_FLOAT16, 120 FROM_INT32_TO_FLOAT, 121 FROM_INT64_TO_INT16, 122 FROM_INT64_TO_INT32, 123 FROM_FLOAT16_TO_UINT8, 124 FROM_FLOAT16_TO_INT32, 125 FROM_FLOAT16_TO_FLOAT, 126 FROM_FLOAT_TO_INT32, 127 FROM_FLOAT_TO_FLOAT16, 128 FROM_FLOAT_TO_BFLOAT16, 129 FROM_FLOAT32_TO_FLOAT64, 130 FROM_FLOAT64_TO_FLOAT32 131 }; 132 const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map = { 133 {std::pair<TypeId, TypeId>(kNumberTypeFloat64, kNumberTypeFloat32), DataTypeTransMode::FROM_FLOAT64_TO_FLOAT32}, 134 {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat64), DataTypeTransMode::FROM_FLOAT32_TO_FLOAT64}, 135 {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat16), DataTypeTransMode::FROM_FLOAT_TO_FLOAT16}, 136 {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeBFloat16), DataTypeTransMode::FROM_FLOAT_TO_BFLOAT16}, 137 {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeInt32), DataTypeTransMode::FROM_FLOAT_TO_INT32}, 138 {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeFloat32), DataTypeTransMode::FROM_FLOAT16_TO_FLOAT}, 139 {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeInt32), DataTypeTransMode::FROM_FLOAT16_TO_INT32}, 140 {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeUInt8), DataTypeTransMode::FROM_FLOAT16_TO_UINT8}, 141 {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat32), DataTypeTransMode::FROM_INT32_TO_FLOAT}, 142 {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat16), DataTypeTransMode::FROM_INT32_TO_FLOAT16}, 143 {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeUInt8), DataTypeTransMode::FROM_INT32_TO_UINT8}, 144 {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt8), DataTypeTransMode::FROM_INT32_TO_INT8}, 145 {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt16), DataTypeTransMode::FROM_INT32_TO_INT16}, 146 {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeUInt16), DataTypeTransMode::FROM_INT32_TO_UINT16}, 147 {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt64), DataTypeTransMode::FROM_INT32_TO_INT64}, 148 {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeBool), DataTypeTransMode::FROM_INT32_TO_BOOL}, 149 {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat32), DataTypeTransMode::FROM_UINT8_TO_FLOAT}, 150 {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeInt32), DataTypeTransMode::FROM_UINT8_TO_INT32}, 151 {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat16), DataTypeTransMode::FROM_UINT8_TO_FLOAT16}, 152 {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat32), DataTypeTransMode::FROM_INT8_TO_FLOAT}, 153 {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat16), DataTypeTransMode::FROM_INT8_TO_FLOAT16}, 154 {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeInt32), DataTypeTransMode::FROM_INT8_TO_INT32}, 155 {std::pair<TypeId, TypeId>(kNumberTypeInt64, kNumberTypeInt16), DataTypeTransMode::FROM_INT64_TO_INT16}, 156 {std::pair<TypeId, TypeId>(kNumberTypeInt64, kNumberTypeInt32), DataTypeTransMode::FROM_INT64_TO_INT32}, 157 {std::pair<TypeId, TypeId>(kNumberTypeUInt16, kNumberTypeInt32), DataTypeTransMode::FROM_UINT16_TO_INT32}, 158 {std::pair<TypeId, TypeId>(kNumberTypeInt16, kNumberTypeInt32), DataTypeTransMode::FROM_INT16_TO_INT32}, 159 {std::pair<TypeId, TypeId>(kNumberTypeInt16, kNumberTypeInt64), DataTypeTransMode::FROM_INT16_TO_INT64}, 160 {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeInt32), DataTypeTransMode::FROM_BOOL_TO_INT32}, 161 {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat), DataTypeTransMode::FROM_BOOL_TO_FLOAT}, 162 {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeUInt8), DataTypeTransMode::FROM_BOOL_TO_UINT8}, 163 {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat16), DataTypeTransMode::FROM_BOOL_TO_FLOAT16}}; 164 165 bool CastKernel(const TypeIdArgs &args, void *dst, int64_t data_size, DataTypeTransMode mode) const; 166 }; 167 168 /** 169 * Trans host shape to device shape according to node's format 170 * */ 171 class BACKEND_EXPORT DeviceShapeTransfer { 172 public: 173 DeviceShapeTransfer() = default; 174 ~DeviceShapeTransfer() = default; 175 ShapeVector GetDeviceShapeByFormat(const ShapeVector &shape, const std::string &format, const AnfNodePtr &node, 176 size_t index, const TypeId &type, bool is_output = true) const; 177 178 ShapeVector GetDeviceShapeByFormat(const ShapeVector &shape, const std::string &format, const TypeId &type, 179 int64_t groups = 1, 180 const ShapeVector &input_hidden_size = {kAlign16, kAlign16}) const; 181 182 private: 183 ShapeVector GetAttrInputAndHiddenSize(const AnfNodePtr &node) const; 184 std::optional<ShapeVector> GetFixedDeviceShape(const ShapeVector &, const AnfNodePtr &node, size_t index, 185 bool is_output = true) const; 186 ShapeVector TransCore(const ShapeVector &shape, const std::string &format, const TypeId &type, int64_t groups = 1, 187 const ShapeVector &input_hidden_size = {kAlign16, kAlign16}) const; 188 189 // trans functions 190 static ShapeVector NCHWDeviceShape(const ShapeVector &shape, const TypeId &); 191 static ShapeVector NHWCDeviceShape(const ShapeVector &shape, const TypeId &); 192 static ShapeVector HWCNDeviceShape(const ShapeVector &shape, const TypeId &); 193 static ShapeVector NCDHWDeviceShape(const ShapeVector &shape, const TypeId &); 194 static ShapeVector NC1HWC04DeviceShape(const ShapeVector &shape, const TypeId &); 195 static ShapeVector FRAC_ZC04DeviceShape(const ShapeVector &shape, const TypeId &); 196 static ShapeVector ChannelLastDeviceShape(const ShapeVector &shape, const TypeId &); 197 static ShapeVector FRAC_ZN_LSTMDeviceShape(const ShapeVector &shape, const TypeId &); 198 static ShapeVector FRAC_ZDeviceShape(const ShapeVector &shape, const TypeId &type); 199 static ShapeVector FRAC_NZDeviceShape(const ShapeVector &shape, const TypeId &type); 200 static ShapeVector NC1HWC0DeviceShape(const ShapeVector &shape, const TypeId &type); 201 static ShapeVector NDC1HWC0DeviceShape(const ShapeVector &shape, const TypeId &type); 202 static ShapeVector FRAC_Z3DDeviceShape(const ShapeVector &shape, const TypeId &type); 203 static ShapeVector C1HWNCOC0DeviceShape(const ShapeVector &shape, const TypeId &type); 204 static ShapeVector NDRNNBiasDeviceShape(const ShapeVector &shape, const TypeId &type, int64_t hidden_size = 16); 205 static ShapeVector FRAC_ZDeviceShapeWithGroups(const ShapeVector &shape, const TypeId &type, int64_t groups = 1); 206 static ShapeVector FRAC_ZN_RNNDeviceShape(const ShapeVector &shape, const TypeId &type, 207 const ShapeVector &input_hidden_size = {kAlign16, kAlign16}); 208 }; 209 210 /** 211 * Trans data at host according to the node's format 212 * */ 213 class FormatTransfer { 214 public: 215 FormatTransfer() = default; 216 ~FormatTransfer() = default; 217 218 bool TransDataByFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, size_t index, bool is_forward); 219 bool TransDataForwardCore(const FormatArgs &args, void *result, int64_t groups = 1); 220 bool TransDataBackwordCore(const FormatArgs &args, void *result, int64_t groups = 1); 221 222 private: 223 using TransferCore = std::function<bool(const FormatArgs &, void *)>; 224 // fp map 225 const std::map<std::string, TransferCore> format_trans_fp_map = {{kOpFormat_HWCN, NCHW_TO_4D}, 226 {kOpFormat_NHWC, NCHW_TO_4D}, 227 {kOpFormat_FRAC_Z, NCHW_TO_FRAC_Z}, 228 {kOpFormat_FRAC_NZ, NCHW_TO_FRAC_NZ}, 229 {kOpFormat_NC1HWC0, NCHW_TO_NC1HWC0}, 230 {kOpFormat_NDC1HWC0, NCDHW_TO_NDC1HWC0}, 231 {kOpFormat_C1HWNCoC0, NCHW_TO_C1HWNCOC0}, 232 {kOpFormat_NC1HWC0_C04, NCHW_TO_NC1HWC04}, 233 {kOpFormat_FRACTAL_Z_3D, NCDHW_TO_FRAC_Z3D}, 234 {kOpFormat_FRACTAL_Z_C04, NCHW_TO_FRAC_ZC04}}; 235 // bp map 236 const std::map<std::string, TransferCore> format_trans_bp_map = {{kOpFormat_HWCN, TO_NCHW}, 237 {kOpFormat_NHWC, TO_NCHW}, 238 {kOpFormat_FRAC_Z, FRAC_Z_TO_NCHW}, 239 {kOpFormat_FRAC_NZ, FRAC_NZ_TO_NCHW}, 240 {kOpFormat_NC1HWC0, NC1HWC0_TO_NCHW}, 241 {kOpFormat_NDC1HWC0, NDC1HWC0_TO_NCDHW}, 242 {kOpFormat_C1HWNCoC0, C1HWNCOC0_TO_NCHW}, 243 {kOpFormat_NC1HWC0_C04, NC1HWC04_TO_NCHW}, 244 {kOpFormat_FRACTAL_Z_3D, FRAC_Z3D_TO_NCDHW}}; 245 246 static bool CheckArgs(const FormatArgs &args, int64_t *size); 247 static bool TransShapeToHW_NZ(const ShapeVector &host_shape, ShapeVector *hw_shape); 248 // HOST TO DEVICE 249 static bool NCHW_TO_4D(const FormatArgs &args, void *result); 250 static bool NCHW_TO_FRAC_Z(const FormatArgs &args, void *result); 251 static bool NCHW_TO_NC1HWC0(const FormatArgs &args, void *result); 252 static bool NCHW_TO_FRAC_NZ(const FormatArgs &args, void *result); 253 static bool NCHW_TO_NC1HWC04(const FormatArgs &args, void *result); 254 static bool NCHW_TO_FRAC_ZC04(const FormatArgs &args, void *result); 255 static bool NCHW_TO_C1HWNCOC0(const FormatArgs &args, void *result); 256 static bool NCDHW_TO_NDC1HWC0(const FormatArgs &args, void *result); 257 static bool NCDHW_TO_FRAC_Z3D(const FormatArgs &args, void *result); 258 static bool NCHW_TO_FRAC_Z_WITH_GROUPS(const FormatArgs &args, void *result, bool to_device, int64_t groups); 259 260 // DEVICE TO HOST 261 static bool TO_NCHW(const FormatArgs &args, void *result); 262 static bool FRAC_Z_TO_NCHW(const FormatArgs &args, void *result); 263 static bool FRAC_NZ_TO_NCHW(const FormatArgs &args, void *result); 264 static bool NC1HWC0_TO_NCHW(const FormatArgs &args, void *result); 265 static bool NC1HWC04_TO_NCHW(const FormatArgs &args, void *result); 266 static bool C1HWNCOC0_TO_NCHW(const FormatArgs &args, void *result); 267 static bool FRAC_Z3D_TO_NCDHW(const FormatArgs &args, void *result); 268 static bool NDC1HWC0_TO_NCDHW(const FormatArgs &args, void *result); 269 static bool FRAC_Z_TO_NCHW_WITH_GROUPS(const FormatArgs &args, void *result, int64_t groups); 270 271 // common check_func 272 static int64_t Common4DCheck(const FormatArgs &args); 273 }; 274 275 /** 276 * Range trans function 277 * */ 278 class BACKEND_EXPORT ShapeRangeTransfer { 279 public: 280 ShapeRangeTransfer() = default; 281 ~ShapeRangeTransfer() = default; 282 RangePair GetRealRange(const RangePair &ori_range, const std::string &format, const TypeId &type, 283 const std::string &padding_str = {""}) const; 284 285 private: 286 static RangePair NHWCRange(const RangePair &ori_range, const TypeId &); 287 static RangePair HWCNRange(const RangePair &ori_range, const TypeId &); 288 static RangePair NC1HWC04Range(const RangePair &ori_range, const TypeId &); 289 static RangePair FRAC_ZC04Range(const RangePair &ori_range, const TypeId &); 290 static RangePair FRAC_ZN_LSTMRange(const RangePair &ori_range, const TypeId &); 291 static RangePair FRAC_ZRange(const RangePair &ori_range, const TypeId &type); 292 static RangePair FRAC_NZRange(const RangePair &ori_range, const TypeId &type); 293 static RangePair NC1HWC0Range(const RangePair &ori_range, const TypeId &type); 294 static RangePair NDC1HWC0Range(const RangePair &ori_range, const TypeId &type); 295 static RangePair C1HWNCOC0Range(const RangePair &ori_range, const TypeId &type); 296 static RangePair FRAC_Z_3DRange(const RangePair &ori_range, const TypeId &type); 297 }; 298 299 /** 300 * If you want extend format, make sure it has a data trans function at host in class 301 * 'FormatTransfer.format_trans_fp_map' 302 * */ 303 static const std::set<std::string> kFormatWithTransFunc = { 304 kOpFormat_HWCN, kOpFormat_NHWC, kOpFormat_FRAC_Z, kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0, 305 kOpFormat_NDC1HWC0, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_3D, kOpFormat_FRACTAL_Z_C04}; 306 307 /** 308 * Interface of datatype trans 309 * */ 310 BACKEND_EXPORT bool TransDataType(const TypeIdArgs &args, void *result); 311 312 /** 313 * Interface of data format trans from host to device 314 * */ 315 BACKEND_EXPORT bool TransFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, size_t index); 316 317 /** 318 * Interface of data format trans from host to device 319 * */ 320 BACKEND_EXPORT bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, int64_t groups = 1); 321 322 /** 323 * Interface of data format trans from device to host 324 * */ 325 BACKEND_EXPORT bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, const AnfNodePtr &node, 326 size_t index); 327 328 /** 329 * 4D reshape type trans, trans reshape_type from string to int 330 * */ 331 BACKEND_EXPORT void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec); 332 333 /** 334 * 5D reshape type trans, trans reshape_type from string to int 335 * */ 336 BACKEND_EXPORT void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec); 337 338 /** 339 * Get shape after padding 340 * */ 341 BACKEND_EXPORT ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index); 342 343 /** 344 * If need padding 345 * */ 346 BACKEND_EXPORT bool IsNeedPadding(const std::string &format, const ShapeVector &shape); 347 348 /** 349 * Padding shape to 5D by default mode 350 * */ 351 template <typename T> 352 std::vector<T> PaddingShapeTo5dDefault(const std::vector<T> &shape, const AnfNodePtr &node = nullptr) { 353 if (shape.size() >= kDim5) { 354 return shape; 355 } 356 std::vector<T> shape_5d(kNcdhw, 1); 357 switch (shape.size()) { 358 case N_ncdhw: 359 return shape_5d; 360 case C_ncdhw: 361 shape_5d[C_ncdhw] = shape[N_ncdhw]; 362 break; 363 case D_ncdhw: 364 shape_5d[C_ncdhw] = shape[N_ncdhw]; 365 shape_5d[D_ncdhw] = shape[C_ncdhw]; 366 break; 367 case H_ncdhw: 368 shape_5d[C_ncdhw] = shape[N_ncdhw]; 369 shape_5d[D_ncdhw] = shape[C_ncdhw]; 370 shape_5d[H_ncdhw] = shape[D_ncdhw]; 371 break; 372 case W_ncdhw: 373 shape_5d[C_ncdhw] = shape[N_ncdhw]; 374 shape_5d[D_ncdhw] = shape[C_ncdhw]; 375 shape_5d[H_ncdhw] = shape[D_ncdhw]; 376 shape_5d[W_ncdhw] = shape[H_ncdhw]; 377 break; 378 default: 379 auto node_info = (node != nullptr) ? ". Node: " + node->fullname_with_scope() : " ."; 380 MS_LOG(INTERNAL_EXCEPTION) << "Unexpected shape :" << shape << node_info; 381 } 382 return shape_5d; 383 } 384 385 /** 386 * Padding shape to 4D by default mode 387 * */ 388 template <typename T> 389 std::vector<T> PaddingShapeTo4dDefault(const std::vector<T> &shape, const AnfNodePtr &node = nullptr) { 390 std::vector<T> shape_4d(kNchwDims, 1); 391 switch (shape.size()) { 392 case kN: 393 return shape_4d; 394 case kC: 395 shape_4d[kC] = shape[kN]; 396 break; 397 case kH: 398 shape_4d[kC] = shape[kN]; 399 shape_4d[kH] = shape[kC]; 400 break; 401 case kW: 402 shape_4d[kC] = shape[kN]; 403 shape_4d[kH] = shape[kC]; 404 shape_4d[kW] = shape[kH]; 405 break; 406 case kNchwDims: 407 return shape; 408 default: 409 auto node_info = (node != nullptr) ? ". Node: " + node->fullname_with_scope() : " ."; 410 MS_LOG(INTERNAL_EXCEPTION) << "Unexpected shape : " << shape << node_info; 411 } 412 return shape_4d; 413 } 414 415 /** 416 * Padding shape to 5D according to reshape type 417 * */ 418 template <typename T> 419 std::vector<T> PaddingShapeTo5d(const std::vector<T> &shape, const std::string &padding_str = {""}) { 420 std::vector<Axis5D> padding_axis; 421 StringToAxisVector5D(padding_str, &padding_axis); 422 if (padding_axis.empty() || shape.size() > padding_axis.size()) { 423 return PaddingShapeTo5dDefault(shape); 424 } 425 std::vector<T> shape_5d(kNcdhw, 1); 426 for (size_t index = 0; index < shape.size(); index++) { 427 shape_5d[padding_axis[index]] = shape[index]; 428 } 429 return shape_5d; 430 } 431 432 /** 433 * Padding shape to 4D according to reshape type 434 * */ 435 template <typename T> 436 std::vector<T> PaddingShapeTo4d(const std::vector<T> &shape, const std::string &padding_str = {""}) { 437 std::vector<Axis> padding_axis; 438 StringToAxisVector4D(padding_str, &padding_axis); 439 if (padding_axis.empty() || shape.size() > padding_axis.size()) { 440 return PaddingShapeTo4dDefault(shape); 441 } 442 std::vector<T> shape_4d(kNchwDims, 1); 443 for (size_t index = 0; index < shape.size(); index++) { 444 shape_4d[padding_axis[index]] = shape[index]; 445 } 446 return shape_4d; 447 } 448 449 /** 450 * Interface of padding shape 451 * */ 452 template <typename T> 453 std::vector<T> PaddingShape(const std::vector<T> &shape, const std::string &format, const std::string &pad_index = {""}, 454 const AnfNodePtr &node = nullptr) { 455 if (node != nullptr) { 456 MS_LOG(DEBUG) << "Start padding shape for node: [" << node->fullname_with_scope() << "], format: " << format 457 << ", detail info: " << node->DebugString(); 458 } 459 460 if (IsOneOf3DFormat(format)) { 461 if (shape.size() >= kDim5) { 462 return shape; 463 } 464 if (shape.size() == 1 && shape[0] == abstract::Shape::kShapeRankAny) { 465 return {-1, -1, -1, -1, -1}; 466 } 467 return PaddingShapeTo5d(shape, pad_index); 468 } 469 470 if (shape.size() == 1 && shape[0] == abstract::Shape::kShapeRankAny) { 471 return {-1, -1, -1, -1}; 472 } 473 return PaddingShapeTo4d(shape, pad_index); 474 } 475 476 /** 477 * Interface of transform pad_index string to AxisVector 478 * */ 479 template <typename T> 480 std::vector<int> StringToAxisVector(const std::vector<T> &shape, const std::string &format, 481 const std::string &pad_index = {""}, const AnfNodePtr &node = nullptr) { 482 if (node != nullptr) { 483 MS_LOG(DEBUG) << "Start transform pad_index to axis_vecor for node: [" << node->fullname_with_scope() 484 << "], format: " << format << ", detail info: " << node->DebugString(); 485 } 486 487 std::vector<int> padding_axis; 488 if (IsOneOf3DFormat(format)) { 489 if (shape.size() >= kDim5) { 490 return padding_axis; 491 } 492 std::vector<Axis5D> padding_axis_5d; 493 StringToAxisVector5D(pad_index, &padding_axis_5d); 494 495 if (padding_axis_5d.empty() || shape.size() != padding_axis_5d.size()) { 496 for (int index = 0; index < static_cast<int>(shape.size()); ++index) { 497 padding_axis.push_back(index); 498 } 499 } else { 500 (void)std::transform(padding_axis_5d.begin(), padding_axis_5d.end(), std::back_inserter(padding_axis), 501 [](Axis5D x) { return static_cast<int>(x); }); 502 } 503 } else { 504 std::vector<Axis> padding_axis_4d; 505 StringToAxisVector4D(pad_index, &padding_axis_4d); 506 507 if (padding_axis_4d.empty() || shape.size() != padding_axis_4d.size()) { 508 for (int index = 0; index < static_cast<int>(shape.size()); ++index) { 509 padding_axis.push_back(index); 510 } 511 } else { 512 (void)std::transform(padding_axis_4d.begin(), padding_axis_4d.end(), std::back_inserter(padding_axis), 513 [](Axis x) { return static_cast<int>(x); }); 514 } 515 } 516 517 return padding_axis; 518 } 519 520 /** 521 * Interface of device shape trance 522 * */ 523 template <typename T> 524 std::vector<T> TransShapeToDevice(const std::vector<T> &shape, const std::string &format, const AnfNodePtr &node, 525 size_t index, TypeId type, bool is_output = true) { 526 ShapeVector shape_before; 527 (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_before), 528 [](T num) { return static_cast<int64_t>(num); }); 529 DeviceShapeTransfer deviceShapeTransfer; 530 auto res = deviceShapeTransfer.GetDeviceShapeByFormat(shape_before, format, node, index, type, is_output); 531 std::vector<T> out_shape; 532 (void)std::transform(res.begin(), res.end(), std::back_inserter(out_shape), 533 [](int64_t num) { return static_cast<T>(num); }); 534 return out_shape; 535 } 536 537 template <typename T> 538 std::vector<T> TransShapeToDevice(const std::vector<T> &shape, const std::string &format, TypeId type, 539 int64_t groups = 1, const ShapeVector &input_hidden_size = {kAlign16, kAlign16}) { 540 ShapeVector shape_before; 541 (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_before), 542 [](T num) { return static_cast<int64_t>(num); }); 543 DeviceShapeTransfer deviceShapeTransfer; 544 auto res = deviceShapeTransfer.GetDeviceShapeByFormat(shape_before, format, type, groups, input_hidden_size); 545 std::vector<T> out_shape; 546 (void)std::transform(res.begin(), res.end(), std::back_inserter(out_shape), 547 [](int64_t num) { return static_cast<T>(num); }); 548 return out_shape; 549 } 550 551 struct FormatInfo { FormatInfoFormatInfo552 FormatInfo(std::string format, bool is_padded) : baseFormat(format), isPadded(is_padded) {} 553 std::string baseFormat = kOpFormat_ND; 554 bool isPadded = false; 555 }; 556 557 class BACKEND_EXPORT FormatHelper { 558 public: 559 static FormatHelper &GetInstance() noexcept; 560 const std::string GetBaseFormat(const std::string &format); 561 bool IsBaseFormatType(const std::string &format); 562 bool IsPadded(const std::string &format); 563 void Clear(); 564 565 private: FormatHelper()566 FormatHelper() { InitInfo(); } 567 ~FormatHelper() = default; 568 void InitInfo(); 569 570 std::unordered_map<std::string, FormatInfo> info; 571 }; 572 } // namespace trans 573 } // namespace mindspore 574 575 #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_MS_DEVICE_SHAPE_TRANSFER_H_ 576