1
2 /**
3 * Copyright 2021-2023 Huawei Technologies Co., Ltd
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 #include "runtime/device/ms_device_shape_transfer.h"
18 #include <functional>
19 #include <unordered_map>
20 #include <numeric>
21 #include <utility>
22 #include <algorithm>
23
24 namespace mindspore {
25 namespace trans {
// Sentinel marking a dynamic (unknown) dimension inside a shape.
static const ShapeValueDType kShapeDimAny = abstract::Shape::kShapeDimAny;

// Element widths in bytes.
const int b1 = 1;
const int b2 = 2;
const int b4 = 4;
const int b8 = 8;
// Cube (fractal block) edge lengths used by the Ascend fractal formats below.
const int64_t kCubeSize = 16;
const int64_t kCube16 = kCubeSize;
const int64_t kCube32 = 32;
const int64_t kCube64 = 64;
const int64_t kCubeSize_C04 = 4;
const int64_t kNiSize = 16;
constexpr int kDims2 = 2;
constexpr int64_t k4 = 4;
// Types whose inner device dim c0 is 64 / 32 instead of the default 16
// (see GetCubeSizeByType): int4 packs 64 per cube, 1-byte ints pack 32.
static const std::set<TypeId> C0_64 = {kNumberTypeInt4};
static const std::set<TypeId> C0_32 = {kNumberTypeUInt8, kNumberTypeInt8};
42 namespace {
// Axis offsets into device shapes — presumably counted from the end of the
// shape (e.g. shape[shape.size() - fnz_w0]); their users are later in this
// file, outside this chunk. TODO confirm against those users.
const size_t hw_h = 1;
const size_t hw_w = 2;
const size_t fnz_w1 = 4;
const size_t fnz_h1 = 3;
const size_t fnz_h0 = 2;
const size_t fnz_w0 = 1;
const size_t fz_n0 = 1;
const size_t fz_ni = 2;
const size_t fz_c0 = 3;
HasShapeDynamic(const ShapeVector & shape_list)52 bool HasShapeDynamic(const ShapeVector &shape_list) {
53 return std::any_of(shape_list.begin(), shape_list.end(), [](int64_t v) { return v == kShapeDimAny; });
54 }
55
CalMaxShape(int64_t ori_val,int64_t new_val)56 inline int64_t CalMaxShape(int64_t ori_val, int64_t new_val) {
57 if (ori_val < 0) {
58 return kShapeDimAny;
59 }
60
61 return new_val;
62 }
63
// Greatest common divisor by Euclid's algorithm.
// Note: by convention here Gcd(a, 0) is defined as 0 (not a).
template <typename T>
T Gcd(T a, T b) {
  if (b == 0) {
    return 0;
  }
  while (a % b != 0) {
    T r = a % b;
    a = b;
    b = r;
  }
  return b;
}
77
// Least common multiple via Gcd; Lcm(a, 0) is defined as 0.
template <typename T>
T Lcm(T a, T b) {
  return b == 0 ? static_cast<T>(0) : (a * b) / Gcd(a, b);
}
86
// Ceiling division; a zero divisor yields 0 instead of trapping.
template <typename T>
T DivCeil(T n1, T n2) {
  if (n2 == 0) {
    return 0;
  }
  return (n1 + n2 - 1) / n2;
}
94
95 template <typename T>
CheckDims(const std::vector<T> & shape)96 bool CheckDims(const std::vector<T> &shape) {
97 if (shape.size() != kDim4) {
98 MS_LOG(ERROR) << "Host shape dims should be 4";
99 return false;
100 }
101 return true;
102 }
103
GetCubeSizeByType(const TypeId & data_type)104 int64_t GetCubeSizeByType(const TypeId &data_type) {
105 if (C0_32.find(data_type) != C0_32.end()) {
106 return kCube32;
107 }
108 if (C0_64.find(data_type) != C0_64.end()) {
109 return kCube64;
110 }
111 return kCube16;
112 }
113
// Pads a (min, max) range vector to 5 NCDHW entries; axes not covered by the
// input keep the default range (1, 1). The case labels reuse the Axis5D enum
// values as *sizes*: an n-element input fills the n axes that follow N
// (assumes N_ncdhw..W_ncdhw are 0..4 — TODO confirm against Axis5D).
RangePair PaddingRangeTo5dDefault(const RangePair &ori_range) {
  RangePair dst_range(kNcdhw, std::pair<int64_t, int64_t>(1, 1));
  switch (ori_range.size()) {
    case N_ncdhw:
      // Size equals N_ncdhw (i.e. an empty input): returned unchanged.
      return ori_range;
    case C_ncdhw:
      dst_range[C_ncdhw] = ori_range[N_ncdhw];
      break;
    case D_ncdhw:
      dst_range[C_ncdhw] = ori_range[N_ncdhw];
      dst_range[D_ncdhw] = ori_range[C_ncdhw];
      break;
    case H_ncdhw:
      dst_range[C_ncdhw] = ori_range[N_ncdhw];
      dst_range[D_ncdhw] = ori_range[C_ncdhw];
      dst_range[H_ncdhw] = ori_range[D_ncdhw];
      break;
    case W_ncdhw:
      dst_range[C_ncdhw] = ori_range[N_ncdhw];
      dst_range[D_ncdhw] = ori_range[C_ncdhw];
      dst_range[H_ncdhw] = ori_range[D_ncdhw];
      dst_range[W_ncdhw] = ori_range[H_ncdhw];
      break;
    default:
      // NOTE(review): unlike the 4-D variant there is no case returning a
      // full 5-element input unchanged, so size 5 lands here and throws —
      // confirm callers never pass an already-5-D range.
      MS_LOG(INTERNAL_EXCEPTION) << "Unexpected shape size: " << ori_range.size();
  }
  return dst_range;
}
142
// Pads a range vector to 5 NCDHW entries, steering each input entry to the
// axis named by `padding_str` (e.g. "NCH"); falls back to the default
// right-after-N layout when the string is empty or too short.
RangePair PaddingRangeTo5D(const RangePair &ori_range, const std::string &padding_str = {""}) {
  std::vector<Axis5D> axis_map;
  StringToAxisVector5D(padding_str, &axis_map);
  const bool use_default = axis_map.empty() || ori_range.size() > axis_map.size();
  if (use_default) {
    return PaddingRangeTo5dDefault(ori_range);
  }
  RangePair padded(kNcdhw, std::pair<int64_t, int64_t>(1, 1));
  for (size_t i = 0; i < ori_range.size(); ++i) {
    padded[axis_map[i]] = ori_range[i];
  }
  return padded;
}
156
// Pads a (min, max) range vector to 4 NCHW entries; uncovered axes keep the
// default range (1, 1). Case labels reuse the Axis enum values as sizes
// (assumes kN..kW are 0..3 — TODO confirm): an empty input yields all
// defaults, a full 4-element input is returned unchanged.
RangePair PaddingRangeTo4dDefault(const RangePair &ori_range) {
  RangePair dst_range(kNchwDims, std::pair<int64_t, int64_t>(1, 1));
  switch (ori_range.size()) {
    case kN:
      // Empty input: every axis defaults to (1, 1).
      return dst_range;
    case kC:
      dst_range[kC] = ori_range[kN];
      break;
    case kH:
      dst_range[kC] = ori_range[kN];
      dst_range[kH] = ori_range[kC];
      break;
    case kW:
      dst_range[kC] = ori_range[kN];
      dst_range[kH] = ori_range[kC];
      dst_range[kW] = ori_range[kH];
      break;
    case kNchwDims:
      // Already 4-D: nothing to pad.
      return ori_range;
    default:
      MS_LOG(INTERNAL_EXCEPTION) << "Unexpected range size: " << ori_range.size();
  }
  return dst_range;
}
181
// Pads a range vector to 4 NCHW entries, steering each input entry to the
// axis named by `padding_str`; falls back to the default layout when the
// string is empty or shorter than the input.
RangePair PaddingRangeTo4D(const RangePair &ori_range, const std::string &padding_str = {""}) {
  std::vector<Axis> axis_map;
  StringToAxisVector4D(padding_str, &axis_map);
  const bool use_default = axis_map.empty() || ori_range.size() > axis_map.size();
  if (use_default) {
    return PaddingRangeTo4dDefault(ori_range);
  }
  RangePair padded(kNchwDims, std::pair<int64_t, int64_t>(1, 1));
  for (size_t i = 0; i < ori_range.size(); ++i) {
    padded[axis_map[i]] = ori_range[i];
  }
  return padded;
}
195 } // namespace
196
StringToAxisVector4D(const std::string & reshape_type_str,std::vector<Axis> * reshape_type_vec)197 void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
198 MS_EXCEPTION_IF_NULL(reshape_type_vec);
199 if (reshape_type_str.empty()) {
200 MS_LOG(DEBUG) << "Reshape type str is empty, no need padding.";
201 return;
202 }
203 for (const auto &c : reshape_type_str) {
204 switch (c) {
205 case 'N':
206 reshape_type_vec->push_back(N);
207 break;
208 case 'C':
209 reshape_type_vec->push_back(C);
210 break;
211 case 'H':
212 reshape_type_vec->push_back(H);
213 break;
214 case 'W':
215 reshape_type_vec->push_back(W);
216 break;
217 default:
218 MS_LOG(INTERNAL_EXCEPTION) << "Unknown axis " << c << "in reshape type.";
219 }
220 }
221 }
222
StringToAxisVector5D(const std::string & reshape_type_str,std::vector<Axis5D> * reshape_type_vec)223 void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec) {
224 MS_EXCEPTION_IF_NULL(reshape_type_vec);
225 if (reshape_type_str.empty()) {
226 MS_LOG(DEBUG) << "Reshape type str is empty, no need padding.";
227 return;
228 }
229 for (const auto &c : reshape_type_str) {
230 switch (c) {
231 case 'N':
232 reshape_type_vec->push_back(N_ncdhw);
233 break;
234 case 'C':
235 reshape_type_vec->push_back(C_ncdhw);
236 break;
237 case 'D':
238 reshape_type_vec->push_back(D_ncdhw);
239 break;
240 case 'H':
241 reshape_type_vec->push_back(H_ncdhw);
242 break;
243 case 'W':
244 reshape_type_vec->push_back(W_ncdhw);
245 break;
246 default:
247 MS_LOG(INTERNAL_EXCEPTION) << "Unknown axis " << c << "in reshape type.";
248 }
249 }
250 }
251
IsNeedPadding(const std::string & format,const ShapeVector & shape)252 bool IsNeedPadding(const std::string &format, const ShapeVector &shape) {
253 if (shape.empty()) {
254 return false;
255 }
256 if (IsDynamicRank(shape) && !IsOneOfDynRankNeedPadShape(format)) {
257 return false;
258 }
259 if (format == kOpFormat_DEFAULT || format == kOpFormat_NCHW || IsOneOfNoPaddingFormat(format)) {
260 return false;
261 } else if (shape.size() < kDim4) {
262 return true;
263 }
264 return false;
265 }
266
GetRuntimePaddingShape(const AnfNodePtr & node,size_t index)267 ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
268 MS_EXCEPTION_IF_NULL(node);
269 ShapeVector host_shape;
270 if (node->isa<ValueNode>()) {
271 auto value_node = node->cast<ValueNodePtr>();
272 MS_EXCEPTION_IF_NULL(value_node);
273 auto node_value = value_node->value();
274 MS_EXCEPTION_IF_NULL(node_value);
275 // Scalar has no shape.
276 if (node_value->isa<Scalar>()) {
277 return {};
278 }
279 if (node_value->isa<StringImm>()) {
280 auto string_value = node_value->cast<StringImmPtr>();
281 MS_EXCEPTION_IF_NULL(string_value);
282 return {SizeToLong(string_value->ToString().size())};
283 }
284 if (node_value->isa<ValueSequence>()) {
285 MS_LOG(INFO) << "GetRuntimePaddingShape does not support the value sequence for value node:"
286 << node->fullname_with_scope() << ", debug name:" << node->DebugString();
287 return {0};
288 }
289 auto tensor = node_value->cast<tensor::TensorPtr>();
290 if (tensor == nullptr) {
291 MS_LOG(INTERNAL_EXCEPTION) << " The node[ " << node->DebugString() << "]'s cannot convert ";
292 }
293 host_shape = tensor->shape();
294 } else {
295 host_shape = common::AnfAlgo::GetOutputInferShape(node, index);
296 }
297 auto format = AnfAlgo::GetOutputFormat(node, index);
298 if (IsNeedPadding(format, host_shape)) {
299 host_shape = PaddingShape(host_shape, format, AnfAlgo::GetOutputReshapeType(node, index), node);
300 }
301 return host_shape;
302 }
303
TransDataType(const TypeIdArgs & args,void * result)304 bool TransDataType(const TypeIdArgs &args, void *result) {
305 DataTypeTransfer dataTypeTransfer;
306 return dataTypeTransfer.TransDataType(args, result);
307 }
308
TransFormat(const FormatArgs & args,void * result,const AnfNodePtr & node,size_t index)309 bool TransFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, size_t index) {
310 FormatTransfer formatTransfer;
311 return formatTransfer.TransDataByFormat(args, result, node, index, true);
312 }
313
TransFormatFromDeviceToHost(const FormatArgs & args,void * result,int64_t groups)314 bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, int64_t groups) {
315 FormatTransfer formatTransfer;
316 return formatTransfer.TransDataBackwordCore(args, result, groups);
317 }
318
TransFormatFromDeviceToHost(const FormatArgs & args,void * result,const AnfNodePtr & node,size_t index)319 bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, const AnfNodePtr &node, size_t index) {
320 FormatTransfer formatTransfer;
321 return formatTransfer.TransDataByFormat(args, result, node, index, false);
322 }
323
324 /**###################### DATA TYPE TRANS ################################*/
CheckMemSize(const TypeIdArgs & args)325 void CheckMemSize(const TypeIdArgs &args) {
326 auto src_type_size = abstract::TypeIdSize(args.src_data_type);
327 auto dst_type_size = abstract::TypeIdSize(args.dst_data_type);
328 if (src_type_size < 1 || dst_type_size < 1) {
329 MS_LOG(INTERNAL_EXCEPTION) << "Invalid src or dst data type. Src type: " << TypeIdLabel(args.src_data_type)
330 << ", dst type: " << TypeIdLabel(args.dst_data_type);
331 }
332 if (SizeToLong(args.data_size / src_type_size) != args.src_shape_size) {
333 MS_LOG(INTERNAL_EXCEPTION) << "Invalid src or dst data shape size. Src shape size: " << args.src_shape_size
334 << ", dst shape size: " << args.data_size / src_type_size;
335 }
336 }
337
338 template <typename SrcT, typename DstT>
TransDataSrc2Dst(const TypeIdArgs & args,void * dst,const int64_t data_size)339 void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const int64_t data_size) {
340 CheckMemSize(args);
341 for (int64_t idx = 0; idx != data_size; idx++) {
342 SrcT src_data = static_cast<const SrcT *>(args.data)[idx];
343 static_cast<DstT *>(dst)[idx] = static_cast<DstT>(src_data);
344 }
345 }
346 template <typename SrcT>
TransDataSrc2Fp16(const TypeIdArgs & args,void * dst,const int64_t data_size)347 void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const int64_t data_size) {
348 CheckMemSize(args);
349 auto src_data = static_cast<const SrcT *>(args.data);
350 auto half_data = static_cast<float16 *>(dst);
351 for (int64_t i = 0; i < data_size; i++) {
352 half_data[i] = float16(src_data[i]);
353 }
354 }
355
CastKernel(const TypeIdArgs & args,void * dst,int64_t data_size,DataTypeTransMode mode) const356 bool DataTypeTransfer::CastKernel(const TypeIdArgs &args, void *dst, int64_t data_size, DataTypeTransMode mode) const {
357 using DtypeKernel = std::function<void(const TypeIdArgs &, void *, const int64_t)>;
358 const std::map<DataTypeTransMode, DtypeKernel> cast_kernel_map{
359 {DataTypeTransMode::FROM_BOOL_TO_UINT8, TransDataSrc2Dst<int8_t, uint8_t>},
360 {DataTypeTransMode::FROM_BOOL_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
361 {DataTypeTransMode::FROM_BOOL_TO_FLOAT16, TransDataSrc2Fp16<int8_t>},
362 {DataTypeTransMode::FROM_BOOL_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
363 {DataTypeTransMode::FROM_INT8_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
364 {DataTypeTransMode::FROM_INT8_TO_FLOAT16, TransDataSrc2Fp16<int8_t>},
365 {DataTypeTransMode::FROM_INT8_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
366 {DataTypeTransMode::FROM_UINT8_TO_INT32, TransDataSrc2Dst<uint8_t, int32_t>},
367 {DataTypeTransMode::FROM_UINT8_TO_FLOAT16, TransDataSrc2Fp16<uint8_t>},
368 {DataTypeTransMode::FROM_UINT8_TO_FLOAT, TransDataSrc2Dst<uint8_t, float>},
369 {DataTypeTransMode::FROM_UINT16_TO_INT32, TransDataSrc2Dst<uint16_t, int32_t>},
370 {DataTypeTransMode::FROM_INT16_TO_INT32, TransDataSrc2Dst<int16_t, int32_t>},
371 {DataTypeTransMode::FROM_INT16_TO_INT64, TransDataSrc2Dst<int16_t, int64_t>},
372 {DataTypeTransMode::FROM_INT32_TO_BOOL, TransDataSrc2Dst<int32_t, int8_t>},
373 {DataTypeTransMode::FROM_INT32_TO_INT8, TransDataSrc2Dst<int32_t, int8_t>},
374 {DataTypeTransMode::FROM_INT32_TO_UINT8, TransDataSrc2Dst<int32_t, uint8_t>},
375 {DataTypeTransMode::FROM_INT32_TO_INT16, TransDataSrc2Dst<int32_t, int16_t>},
376 {DataTypeTransMode::FROM_INT32_TO_UINT16, TransDataSrc2Dst<int32_t, uint16_t>},
377 {DataTypeTransMode::FROM_INT32_TO_FLOAT16, TransDataSrc2Fp16<int32_t>},
378 {DataTypeTransMode::FROM_INT32_TO_FLOAT, TransDataSrc2Dst<int32_t, float>},
379 {DataTypeTransMode::FROM_INT32_TO_INT64, TransDataSrc2Dst<int32_t, int64_t>},
380 {DataTypeTransMode::FROM_INT64_TO_INT16, TransDataSrc2Dst<int64_t, int16_t>},
381 {DataTypeTransMode::FROM_INT64_TO_INT32, TransDataSrc2Dst<int64_t, int32_t>},
382 {DataTypeTransMode::FROM_FLOAT16_TO_UINT8, TransDataSrc2Dst<float16, uint8_t>},
383 {DataTypeTransMode::FROM_FLOAT16_TO_INT32, TransDataSrc2Dst<float16, int32_t>},
384 {DataTypeTransMode::FROM_FLOAT_TO_INT32, TransDataSrc2Dst<float, int32_t>},
385 {DataTypeTransMode::FROM_FLOAT32_TO_FLOAT64, TransDataSrc2Dst<float, double>},
386 {DataTypeTransMode::FROM_FLOAT64_TO_FLOAT32, TransDataSrc2Dst<double, float>}};
387
388 if (mode == DataTypeTransMode::FROM_FLOAT_TO_FLOAT16) {
389 device::FloatToHalf(dst, args.data, LongToSize(data_size));
390 return true;
391 } else if (mode == DataTypeTransMode::FROM_FLOAT16_TO_FLOAT) {
392 device::HalfToFloat(dst, args.data, LongToSize(data_size));
393 return true;
394 }
395 auto iter = cast_kernel_map.find(mode);
396 if (iter != cast_kernel_map.end()) {
397 iter->second(args, dst, data_size);
398 return true;
399 } else {
400 MS_LOG(ERROR) << "Can not find a datatype trans function. Src type :" << TypeIdLabel(args.src_data_type)
401 << ", dst_type:" << TypeIdLabel(args.dst_data_type);
402 return false;
403 }
404 }
405
TransDataType(const TypeIdArgs & args,void * result) const406 bool DataTypeTransfer::TransDataType(const TypeIdArgs &args, void *result) const {
407 MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.src_data_type) << " to "
408 << TypeIdLabel(args.dst_data_type);
409 MS_EXCEPTION_IF_NULL(result);
410 std::pair<TypeId, TypeId> type_info(args.src_data_type, args.dst_data_type);
411 auto iter = mode_map.find(type_info);
412 if (iter == mode_map.end()) {
413 MS_LOG(ERROR) << "Can not find a datatype trans type. src_type :" << TypeIdLabel(args.src_data_type)
414 << ", dst_type:" << TypeIdLabel(args.dst_data_type);
415 return false;
416 }
417 auto trans_mode = iter->second;
418 if (!CastKernel(args, result, args.src_shape_size, trans_mode)) {
419 MS_LOG(ERROR) << "Failed to trans datatype. Src: " << TypeIdLabel(args.src_data_type)
420 << ", dst: " << TypeIdLabel(args.dst_data_type);
421 return false;
422 }
423 return true;
424 }
425
426 /**###################### DATA SHAPE TRANS ################################*/
// Trans an infer (host) shape to the device shape for `format`, preferring a
// fixed device shape attribute on the node when present. Grouped FRAC_Z and
// the RNN formats pull their extra parameters from node attributes.
ShapeVector DeviceShapeTransfer::GetDeviceShapeByFormat(const ShapeVector &shape, const std::string &format,
                                                        const AnfNodePtr &node, size_t index, const TypeId &type,
                                                        bool is_output) const {
  const auto fixed_shape = GetFixedDeviceShape(shape, node, index, is_output);
  if (fixed_shape.has_value()) {
    return fixed_shape.value();
  }
  const int64_t groups = (format == kOpFormat_FRAC_Z) ? common::AnfAlgo::GetAttrGroups(node, index) : 1;
  ShapeVector input_hidden_size{kAlign16, kAlign16};
  if (format == kOpFormat_FRACTAL_ZN_RNN || format == kOpFormat_ND_RNN_BIAS) {
    input_hidden_size = GetAttrInputAndHiddenSize(node);
  }
  if (node != nullptr) {
    MS_LOG(DEBUG) << "Start trans infer shape to device shape for node: " << node->DebugString()
                  << ", format: " << format;
  }
  return TransCore(shape, format, type, groups, input_hidden_size);
}
448
// Node-free overload: trans `shape` to the device shape for `format` with an
// explicit `groups` (grouped FRAC_Z) and `input_hidden_size` (RNN formats).
ShapeVector DeviceShapeTransfer::GetDeviceShapeByFormat(const ShapeVector &shape, const std::string &format,
                                                        const TypeId &type, int64_t groups,
                                                        const ShapeVector &input_hidden_size) const {
  return TransCore(shape, format, type, groups, input_hidden_size);
}
454
GetFixedDeviceShape(const ShapeVector &,const AnfNodePtr & node,size_t index,bool is_output) const455 std::optional<ShapeVector> DeviceShapeTransfer::GetFixedDeviceShape(const ShapeVector &, const AnfNodePtr &node,
456 size_t index, bool is_output) const {
457 if (node == nullptr || !node->isa<CNode>()) {
458 return {};
459 }
460 auto attr_name = is_output ? kAttrFixedOutputDeviceShape : kAttrFixedInputDeviceShape;
461 auto cnode = node->cast<CNodePtr>();
462 if (!common::AnfAlgo::HasNodeAttr(attr_name, cnode)) {
463 return {};
464 }
465
466 auto shapes = common::AnfAlgo::GetNodeAttr<std::vector<ShapeVector>>(cnode, attr_name);
467 if (index >= shapes.size()) {
468 MS_LOG(INFO) << "Index is out of range, got index: " << index << ", shape size: " << shapes.size();
469 return {};
470 }
471 return std::optional<ShapeVector>(std::move(shapes[index]));
472 }
473
// Core dispatcher: trans an infer (host) shape into the device shape for
// `format`. Order matters: pass-through formats first, then the special
// cases that need extra arguments (grouped FRAC_Z, RNN formats), then
// padding to 4-D/5-D before the per-format table lookup.
ShapeVector DeviceShapeTransfer::TransCore(const ShapeVector &shape, const std::string &format, const TypeId &type,
                                           int64_t groups, const ShapeVector &input_hidden_size) const {
  using DeviceShapeTransferFunc = std::function<ShapeVector(const ShapeVector &, const TypeId &)>;
  static const mindspore::HashMap<std::string, DeviceShapeTransferFunc> device_shape_map = {
    {kOpFormat_NCHW, NCHWDeviceShape},
    {kOpFormat_NHWC, NHWCDeviceShape},
    {kOpFormat_HWCN, HWCNDeviceShape},
    {kOpFormat_NCDHW, NCDHWDeviceShape},
    {kOpFormat_FRAC_Z, FRAC_ZDeviceShape},
    {kOpFormat_FRAC_NZ, FRAC_NZDeviceShape},
    {kOpFormat_NC1HWC0, NC1HWC0DeviceShape},
    {kOpFormat_NDC1HWC0, NDC1HWC0DeviceShape},
    {kOpFormat_C1HWNCoC0, C1HWNCOC0DeviceShape},
    {kOpFormat_NC1HWC0_C04, NC1HWC04DeviceShape},
    {kOpFormat_FRACTAL_Z_3D, FRAC_Z3DDeviceShape},
    {kOpFormat_FRACTAL_Z_C04, FRAC_ZC04DeviceShape},
    {kOpFormat_ChannelLast, ChannelLastDeviceShape},
    {kOpFormat_FRACTAL_ZN_LSTM, FRAC_ZN_LSTMDeviceShape}};
  // Host-layout formats: the infer shape already is the device shape.
  if (format == kOpFormat_ND || format == kOpFormat_DEFAULT || format == kOpFormat_NCHW) {
    return shape;
  }
  if (groups > 1 && format == kOpFormat_FRAC_Z) {
    return FRAC_ZDeviceShapeWithGroups(shape, type, groups);
  }
  if (format == kOpFormat_FRACTAL_ZN_RNN) {
    return FRAC_ZN_RNNDeviceShape(shape, type, input_hidden_size);
  }
  if (format == kOpFormat_ND_RNN_BIAS) {
    // NOTE(review): assumes input_hidden_size has >= 2 entries; callers in
    // this file pass {kAlign16, kAlign16} by default — confirm other paths.
    return NDRNNBiasDeviceShape(shape, type, input_hidden_size[1]);
  }
  // Pad short shapes up to the rank the target format expects (4-D for the
  // NCHW-family formats, 5-D for the 3-D formats) before the table lookup.
  auto temp_shape = shape;
  if (!IsOneOfNoPaddingFormat(format) && format != kOpFormat_FRACTAL_ZN_LSTM && shape.size() < kDim4 &&
      !IsOneOf3DFormat(format)) {
    MS_LOG(INFO) << "Origin shape size is less than 4, should be Padding shape by Default firstly";
    temp_shape = PaddingShapeTo4dDefault(shape);
  }
  if (shape.size() != kDim5 && IsOneOf3DFormat(format)) {
    temp_shape = PaddingShapeTo5dDefault(shape);
  }
  auto iter = device_shape_map.find(format);
  if (iter == device_shape_map.end()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Unexpected format[" << format << "]";
  }
  return iter->second(temp_shape, type);
}
519
// NCHW device shape equals the host shape; only the 4-D rank is validated.
ShapeVector DeviceShapeTransfer::NCHWDeviceShape(const ShapeVector &shape, const TypeId &) {
  if (!CheckDims(shape)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed.";
  }
  return shape;
}
526
// Permutes a 4-D NCHW host shape into NHWC order.
ShapeVector DeviceShapeTransfer::NHWCDeviceShape(const ShapeVector &shape, const TypeId &) {
  if (!CheckDims(shape)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed.";
  }
  return {shape[kN], shape[kH], shape[kW], shape[kC]};
}
538
// Permutes a 4-D NCHW host shape into HWCN order.
ShapeVector DeviceShapeTransfer::HWCNDeviceShape(const ShapeVector &shape, const TypeId &) {
  if (!CheckDims(shape)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed.";
  }
  return {shape[kH], shape[kW], shape[kC], shape[kN]};
}
550
// FRAC_Z device shape: [H*W*ceil(C/c0), ceil(N/Ni), Ni, c0].
// Any dynamic C/H/W makes the first dim dynamic; a dynamic N makes the
// second dim dynamic.
ShapeVector DeviceShapeTransfer::FRAC_ZDeviceShape(const ShapeVector &shape, const TypeId &type) {
  if (!CheckDims(shape)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed.";
  }
  const auto c0 = GetCubeSizeByType(type);
  int64_t hw_c1 = abstract::Shape::kShapeDimAny;
  if (!HasShapeDynamic({shape[kC], shape[kH], shape[kW]})) {
    const auto c1 = (shape[kC] + c0 - 1) / c0;
    hw_c1 = shape[kH] * shape[kW] * c1;
  }
  int64_t no = abstract::Shape::kShapeDimAny;
  if (shape[kN] != abstract::Shape::kShapeDimAny) {
    no = (shape[kN] + kNiSize - 1) / kNiSize;
  }
  return {hw_c1, no, kNiSize, c0};
}
573
// NC1HWC0 device shape: [N, ceil(C/c0), H, W, c0], propagating a dynamic C.
ShapeVector DeviceShapeTransfer::NC1HWC0DeviceShape(const ShapeVector &shape, const TypeId &type) {
  if (!CheckDims(shape)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed.";
  }
  const auto c0 = GetCubeSizeByType(type);
  int64_t c1 = abstract::Shape::kShapeDimAny;
  if (shape[kC] != abstract::Shape::kShapeDimAny) {
    c1 = (shape[kC] + c0 - 1) / c0;
  }
  return {shape[kN], c1, shape[kH], shape[kW], c0};
}
588
// NDC1HWC0 device shape from a 5-D NCDHW host shape:
// [N, D, ceil(C/c0), H, W, c0]. A 6-D input is already in device layout.
ShapeVector DeviceShapeTransfer::NDC1HWC0DeviceShape(const ShapeVector &shape, const TypeId &type) {
  if (shape.size() == kDim6) {
    return shape;
  }
  if (shape.size() != kDim5) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
  }
  const auto c0 = GetCubeSizeByType(type);
  int64_t c1 = abstract::Shape::kShapeDimAny;
  if (shape[1] != abstract::Shape::kShapeDimAny) {
    c1 = (shape[1] + c0 - 1) / c0;
  }
  return {shape[N_ncdhw], shape[D_ncdhw], c1, shape[H_ncdhw], shape[W_ncdhw], c0};
}
607
// FRACTAL_Z_3D device shape from a 5-D NCDHW host shape:
// [D*ceil(C/c0)*H*W, ceil(N/Ni), Ni, c0], with dynamic dims propagated.
ShapeVector DeviceShapeTransfer::FRAC_Z3DDeviceShape(const ShapeVector &shape, const TypeId &type) {
  if (shape.size() != kDim5) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
  }
  const auto c0 = GetCubeSizeByType(type);
  int64_t dc1hw = abstract::Shape::kShapeDimAny;
  if (!HasShapeDynamic({shape[C_ncdhw], shape[D_ncdhw], shape[H_ncdhw], shape[W_ncdhw]})) {
    const auto c1 = (shape[1] + c0 - 1) / c0;
    dc1hw = shape[D_ncdhw] * c1 * shape[H_ncdhw] * shape[W_ncdhw];
  }
  int64_t no = abstract::Shape::kShapeDimAny;
  if (shape[0] != abstract::Shape::kShapeDimAny) {
    no = (shape[0] + kNiSize - 1) / kNiSize;
  }
  return {dc1hw, no, kNiSize, c0};
}
627
// C1HWNCoC0 device shape: [ceil(C/c0), H, W, N, c0, c0], propagating a
// dynamic C into the first dim.
ShapeVector DeviceShapeTransfer::C1HWNCOC0DeviceShape(const ShapeVector &shape, const TypeId &type) {
  if (!CheckDims(shape)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed.";
  }
  const auto c0 = GetCubeSizeByType(type);
  int64_t c1 = abstract::Shape::kShapeDimAny;
  if (shape[kC] != abstract::Shape::kShapeDimAny) {
    c1 = (shape[kC] - 1) / c0 + 1;
  }
  return {c1, shape[kH], shape[kW], shape[kN], c0, c0};
}
646
// FRACTAL_Z_C04 device shape: [ceil(4*H*W/16), ceil(N/16), 16, 16] with a
// fixed c0 of 4; dynamic H/W or N propagate the sentinel.
ShapeVector DeviceShapeTransfer::FRAC_ZC04DeviceShape(const ShapeVector &shape, const TypeId &) {
  if (!CheckDims(shape)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed.";
  }
  const int64_t C04 = 4;
  int64_t hw_dim = abstract::Shape::kShapeDimAny;
  if (!HasShapeDynamic({shape[kH], shape[kW]})) {
    hw_dim = DivCeil(C04 * shape[kH] * shape[kW], kCubeSize);
  }
  int64_t no = abstract::Shape::kShapeDimAny;
  if (shape[kN] != abstract::Shape::kShapeDimAny) {
    no = DivCeil(shape.at(kN), kCubeSize);
  }
  return {hw_dim, no, kCubeSize, kCubeSize};
}
667
// NC1HWC0_C04 device shape: like NC1HWC0 but with a fixed c0 of 4:
// [N, ceil(C/4), H, W, 4], propagating a dynamic C.
ShapeVector DeviceShapeTransfer::NC1HWC04DeviceShape(const ShapeVector &shape, const TypeId &) {
  if (!CheckDims(shape)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed.";
  }
  const int64_t C04 = 4;
  int64_t c1 = abstract::Shape::kShapeDimAny;
  if (shape[kC] != abstract::Shape::kShapeDimAny) {
    c1 = DivCeil(shape.at(kC), C04);
  }
  return {shape[kN], c1, shape[kH], shape[kW], C04};
}
683
// Device shape for NCDHW: the host shape is used as-is; only the rank is
// validated (>= 5, matching the original check).
ShapeVector DeviceShapeTransfer::NCDHWDeviceShape(const ShapeVector &shape, const TypeId &) {
  if (shape.size() < kDim5) {
    // Fixed: the message used to name the wrong format ("ndhwc").
    MS_LOG(INTERNAL_EXCEPTION) << "Shape dims should be 5 when format is NCDHW, but got " << shape.size();
  }
  return shape;
}
690
// Moves the channel axis (index 1) to the end: NC... -> N...C, by applying
// the permutation [0, 2, 3, ..., dim-1, 1] to `shape`.
ShapeVector DeviceShapeTransfer::ChannelLastDeviceShape(const ShapeVector &shape, const TypeId &) {
  const auto dim = shape.size();
  // Guard: the permutation below indexed out of bounds for small ranks —
  // rank 0 wrote axis[dim - 1] on an empty vector (and took begin() + 1 past
  // end()), rank 1 read shape[1]. A rank <= 1 shape has no channel axis to
  // move, so return it unchanged.
  if (dim <= 1) {
    return shape;
  }
  ShapeVector axis;
  axis.resize(dim);  // zero-initialized, so axis[0] stays 0 (keep N first)
  const int step_value = 2;
  std::iota(axis.begin() + 1, axis.end(), step_value);
  axis[dim - 1] = 1;  // channel axis goes last
  ShapeVector device_shape;
  (void)std::transform(axis.begin(), axis.end(), std::back_inserter(device_shape),
                       [&shape](size_t n) { return shape[n]; });
  return device_shape;
}
703
// FRACTAL_NZ device shape: the trailing two host dims (H, W) become
// (W1, H1, H0, W0) = (ceil(W/c0), ceil(H/16), 16, c0); leading dims copy over.
ShapeVector DeviceShapeTransfer::FRAC_NZDeviceShape(const ShapeVector &shape, const TypeId &type) {
  ShapeVector device_shape;
  auto c0 = GetCubeSizeByType(type);
  if (shape.size() == 1 && (shape[0] == 1 || shape[0] % c0 == 0)) {
    // For [1] and [1024] shape we can trait it as NZ shape
    return shape;
  }
  if (shape.size() == 1) {
    // 1-D input not divisible by c0: round W up to blocks of c0, H1 = 1.
    device_shape.push_back(DivCeil(shape[0], c0));
    device_shape.push_back(1);
    device_shape.push_back(kCubeSize);
    device_shape.push_back(c0);
    return device_shape;
  } else {
    // Copy every dim except the trailing two, which are re-blocked below.
    const auto remove_dim = 2;
    (void)std::copy(shape.begin(), shape.end() - remove_dim, std::back_inserter(device_shape));
  }
  // NOTE: kH (== 2 in the NCHW axis enum) is reused here as an offset from
  // the end, i.e. the second-to-last dim.
  int64_t h_shape = shape[shape.size() - kH];
  int64_t w_shape = shape[shape.size() - 1];
  // Dynamic dims propagate the sentinel; otherwise ceil-divide into blocks.
  int64_t w1 = (w_shape == abstract::Shape::kShapeDimAny) ? abstract::Shape::kShapeDimAny : (w_shape - 1) / c0 + 1;
  int64_t h1 =
    (h_shape == abstract::Shape::kShapeDimAny) ? abstract::Shape::kShapeDimAny : (h_shape - 1) / kCubeSize + 1;
  device_shape.push_back(w1);
  device_shape.push_back(h1);
  device_shape.push_back(kCubeSize);
  device_shape.push_back(c0);
  return device_shape;
}
732
// FRACTAL_ZN_LSTM device shape. The host N dim packs the 4 LSTM gates
// (h = N/4) and the host C dim is input+hidden (i = C - h); device shape is
// [ceil(i/16) + ceil(h/16), 4*ceil(h/16), 16, 16]. Dynamic N or C leaves the
// first two dims dynamic.
ShapeVector DeviceShapeTransfer::FRAC_ZN_LSTMDeviceShape(const ShapeVector &shape, const TypeId &) {
  const int64_t lstm_ni = 4;
  const int64_t ni = 16;
  int64_t dim0 = abstract::Shape::kShapeDimAny;
  int64_t dim1 = abstract::Shape::kShapeDimAny;
  if (!HasShapeDynamic({shape[kN], shape[kC]})) {
    const int64_t h = shape.at(kN) / lstm_ni;
    const int64_t i = shape.at(kC) - h;
    dim0 = DivCeil(i, ni) + DivCeil(h, ni);
    dim1 = lstm_ni * DivCeil(h, ni);
  }
  return {dim0, dim1, ni, ni};
}
751
// FRAC_Z device shape for grouped convolution weights.
// e_mult packs several groups into one cube block: it is the smallest
// multiplier (via Lcm) that makes both e_mult*Cin and e_mult*Cout_per_group
// cube-aligned, capped at `groups`. Result layout:
// [G * C1 * H * W, N1, Ni(16), cube_size], with dynamic dims propagated.
ShapeVector DeviceShapeTransfer::FRAC_ZDeviceShapeWithGroups(const ShapeVector &shape, const TypeId &type,
                                                             int64_t groups) {
  if (!CheckDims(shape)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed.";
  }
  if (groups <= 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "The value of groups should be greater than 0, but got " << groups;
  }
  auto cube_size = GetCubeSizeByType(type);
  // Defaults: a dynamic C or N leaves every derived dim dynamic.
  auto c1_dim = abstract::Shape::kShapeDimAny;
  auto g_dim = abstract::Shape::kShapeDimAny;
  auto n1 = abstract::Shape::kShapeDimAny;
  if (shape.size() < kShape2dDims) {
    MS_LOG(INTERNAL_EXCEPTION) << "Format FRAC_ZDeviceShape don't support shape with " << shape.size() << " dims";
  }
  if (!HasShapeDynamic({shape[kC], shape[kN]})) {
    auto group_size = groups;
    auto cin_ori_tmp = static_cast<int64_t>(shape[kC]);                // per-group input channels
    auto cout_ori_tmp = static_cast<int64_t>(shape[kN]) / group_size;  // per-group output channels
    // Smallest multiplier aligning both cin and cout to cube_size, capped
    // at the number of groups.
    auto e_mult =
      std::min(Lcm(Lcm(cin_ori_tmp, cube_size) / cin_ori_tmp, Lcm(cout_ori_tmp, cube_size) / cout_ori_tmp), group_size);
    auto cin_opt = DivCeil(e_mult * cin_ori_tmp, cube_size) * cube_size;  // cube-padded input channels
    c1_dim = cin_opt / cube_size;
    g_dim = DivCeil(group_size, e_mult);  // number of packed group blocks
    n1 = DivCeil(cout_ori_tmp * e_mult, cube_size);
  }
  ShapeVector device_shape;
  if (!HasShapeDynamic({shape[kC], shape[kN], shape[kH], shape[kW]})) {
    device_shape.push_back(g_dim * c1_dim * shape[kH] * shape[kW]);
  } else {
    device_shape.push_back(abstract::Shape::kShapeDimAny);
  }
  device_shape.push_back(n1);
  device_shape.push_back(kNiSize);
  device_shape.push_back(cube_size);
  return device_shape;
}
789
ShapeVector DeviceShapeTransfer::FRAC_ZN_RNNDeviceShape(const ShapeVector &shape, const TypeId &type,
                                                        const ShapeVector &input_hidden_size) {
  // FRACTAL_ZN_RNN device shape: the last two host dims are tiled to cubes.
  // The second-last dim must equal input_size, hidden_size or their sum; the
  // last dim must be a whole number of hidden_size chunks.
  if (shape.size() < kShape2dDims) {
    MS_LOG(INTERNAL_EXCEPTION) << "Format FRACTAL_NZ_RNN don't support shape with " << shape.size() << " dims";
  }
  auto C0 = GetCubeSizeByType(type);
  auto input_size = input_hidden_size[0];
  auto hidden_size = input_hidden_size[1];
  auto dim_last1 = shape[shape.size() - 1];
  auto dim_last2 = shape[shape.size() - kDim2];
  const int64_t NUM16 = 16;

  ShapeVector device_shape = shape;
  if (dim_last2 == abstract::Shape::kShapeDimAny) {
    device_shape[shape.size() - kDim2] = abstract::Shape::kShapeDimAny;
  } else if (dim_last2 == input_size || dim_last2 == hidden_size) {
    device_shape[shape.size() - kDim2] = DivCeil(dim_last2, NUM16);
  } else if (dim_last2 == input_size + hidden_size) {
    device_shape[shape.size() - kDim2] = DivCeil(input_size, NUM16) + DivCeil(hidden_size, NUM16);
  } else {
    MS_LOG(INTERNAL_EXCEPTION) << "The second-last dim value of shape is invalid.";
  }
  if (dim_last1 == abstract::Shape::kShapeDimAny) {
    device_shape[shape.size() - kDim1] = abstract::Shape::kShapeDimAny;
  } else {
    // Guard hidden_size like NDRNNBiasDeviceShape does: a non-positive value
    // would make the modulo below a division by zero / UB.
    if (hidden_size <= 0 || dim_last1 % hidden_size != 0) {
      MS_LOG(INTERNAL_EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size "
                                 << hidden_size;
    }
    int64_t n_num = shape[shape.size() - 1] / hidden_size;
    device_shape[shape.size() - kDim1] = n_num * DivCeil(hidden_size, C0);
  }
  device_shape.push_back(NUM16);
  device_shape.push_back(C0);
  return device_shape;
}
826
ShapeVector DeviceShapeTransfer::NDRNNBiasDeviceShape(const ShapeVector &shape, const TypeId &type,
                                                      int64_t hidden_size) {
  // ND_RNN_BIAS keeps all leading dims and pads only the last one so each
  // hidden_size chunk becomes a whole number of C0 tiles.
  if (shape.empty()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Format ND_RNN_BIAS don't support empty shape.";
  }
  auto C0 = GetCubeSizeByType(type);
  ShapeVector device_shape = shape;
  // cppcheck-suppress *
  const size_t last_pos = shape.size() - 1;
  const int64_t last_dim = shape[last_pos];
  if (last_dim == abstract::Shape::kShapeDimAny) {
    // Dynamic last dim stays dynamic.
    device_shape[last_pos] = abstract::Shape::kShapeDimAny;
    return device_shape;
  }
  if (hidden_size <= 0 || last_dim % hidden_size != 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size "
                               << hidden_size;
  }
  const int64_t n_num = last_dim / hidden_size;
  device_shape[last_pos] = n_num * DivCeil(hidden_size, C0) * C0;
  return device_shape;
}
848
GetAttrInputAndHiddenSize(const AnfNodePtr & node) const849 ShapeVector DeviceShapeTransfer::GetAttrInputAndHiddenSize(const AnfNodePtr &node) const {
850 MS_EXCEPTION_IF_NULL(node);
851 std::vector<int64_t> input_hidden_size = {kAlign16, kAlign16};
852 if (!node->isa<CNode>() && !node->isa<Parameter>()) {
853 return input_hidden_size;
854 }
855
856 if (node->isa<Parameter>()) {
857 auto param = node->cast<ParameterPtr>();
858 input_hidden_size[0] = param->input_size();
859 input_hidden_size[1] = param->hidden_size();
860 } else {
861 CNodePtr cnode = node->cast<CNodePtr>();
862 if (cnode == nullptr || !common::AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode) ||
863 !common::AnfAlgo::HasNodeAttr(kAttrInputSize, cnode)) {
864 MS_LOG(INTERNAL_EXCEPTION)
865 << "Node with format FRACTAL_ZN_RNN or ND_RNN_BIAS should have hidden_size or input_size attr. Node info:"
866 << node->DebugString();
867 }
868 input_hidden_size[0] = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrInputSize);
869 input_hidden_size[1] = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrHiddenSize);
870 }
871 return input_hidden_size;
872 }
873
874 /**###################### DATA FORMAT TRANS ################################*/
SetData(int64_t size,bool pad_zero,int64_t src_idx,int64_t dst_idx,const FormatArgs & args,void * result)875 inline void SetData(int64_t size, bool pad_zero, int64_t src_idx, int64_t dst_idx, const FormatArgs &args,
876 void *result) {
877 switch (size) {
878 case b1:
879 static_cast<uint8_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint8_t *>(args.data)[src_idx];
880 break;
881 case b2:
882 static_cast<uint16_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint16_t *>(args.data)[src_idx];
883 break;
884 case b4:
885 static_cast<uint32_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint32_t *>(args.data)[src_idx];
886 break;
887 case b8:
888 static_cast<uint64_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint64_t *>(args.data)[src_idx];
889 break;
890 default:
891 MS_LOG(EXCEPTION) << "Trans data not support size " << size;
892 }
893 }
894
TransDataByFormat(const FormatArgs & args,void * result,const AnfNodePtr & node,size_t index,bool is_forward)895 bool FormatTransfer::TransDataByFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, size_t index,
896 bool is_forward) {
897 int64_t groups = 1;
898 if (args.device_format == kOpFormat_FRAC_Z && node != nullptr) {
899 groups = common::AnfAlgo::GetAttrGroups(node, index);
900 }
901 if (is_forward) {
902 return TransDataForwardCore(args, result, groups);
903 }
904 return TransDataBackwordCore(args, result, groups);
905 }
906
TransDataForwardCore(const FormatArgs & args,void * result,int64_t groups)907 bool FormatTransfer::TransDataForwardCore(const FormatArgs &args, void *result, int64_t groups) {
908 MS_LOG(DEBUG) << "Start trans format.";
909 if (abstract::TypeIdSize(args.src_data_type) < 1) {
910 MS_LOG(ERROR) << "Invalid datatype: " << args.src_data_type;
911 return false;
912 }
913 if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) {
914 return NCHW_TO_FRAC_Z_WITH_GROUPS(args, result, true, groups);
915 }
916 auto iter = format_trans_fp_map.find(args.device_format);
917 if (iter == format_trans_fp_map.end()) {
918 MS_LOG(INTERNAL_EXCEPTION) << "Unexpected format[" << args.device_format << "]";
919 }
920 return iter->second(args, result);
921 }
922
TransDataBackwordCore(const FormatArgs & args,void * result,int64_t groups)923 bool FormatTransfer::TransDataBackwordCore(const FormatArgs &args, void *result, int64_t groups) {
924 MS_LOG(DEBUG) << "Start trans format.";
925 if (abstract::TypeIdSize(args.src_data_type) < 1) {
926 MS_LOG(ERROR) << "Invalid datatype, type: " << args.src_data_type;
927 return false;
928 }
929 if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) {
930 return FRAC_Z_TO_NCHW_WITH_GROUPS(args, result, groups);
931 }
932 auto iter = format_trans_bp_map.find(args.device_format);
933 if (iter == format_trans_bp_map.end()) {
934 MS_LOG(INTERNAL_EXCEPTION) << "Unexpected format[" << args.device_format << "]";
935 }
936 return iter->second(args, result);
937 }
938
// Validate a 4-D transform request: host shape rank, element size, and that
// device_shape element count times the element size matches the reported
// device buffer size. On success *size receives the element size in bytes.
bool FormatTransfer::CheckArgs(const FormatArgs &args, int64_t *size) {
  if (args.host_shape.size() != kDim4) {
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
  MS_EXCEPTION_IF_NULL(size);
  *size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  if (*size < 1) {
    MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
    return false;
  }
  // The device buffer must hold exactly ShapeSize(device_shape) elements.
  auto total_size = abstract::ShapeSize(args.device_shape) * (*size);
  if (total_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  return true;
}
957
TransShapeToHW_NZ(const ShapeVector & host_shape,ShapeVector * hw_shape)958 bool FormatTransfer::TransShapeToHW_NZ(const ShapeVector &host_shape, ShapeVector *hw_shape) {
959 MS_EXCEPTION_IF_NULL(hw_shape);
960 if (host_shape.empty()) {
961 MS_LOG(ERROR) << "Size of vector is 0.";
962 return false;
963 }
964 switch (host_shape.size()) {
965 case 1:
966 hw_shape->push_back(1);
967 hw_shape->push_back(1);
968 hw_shape->push_back(host_shape[0]);
969 return true;
970 default:
971 auto size = host_shape.size();
972 if (size < kDim2) {
973 MS_LOG(ERROR) << "Illegal size: " << size;
974 return false;
975 }
976 int64_t times = 1;
977 for (size_t i = 0; i != size - kDim2; i++) {
978 times *= host_shape[i];
979 }
980 hw_shape->push_back(times);
981 hw_shape->push_back(host_shape[size - kDim2]);
982 hw_shape->push_back(host_shape[size - kDim1]);
983 return true;
984 }
985 }
986
NCHW_TO_4D(const FormatArgs & args,void * result)987 bool FormatTransfer::NCHW_TO_4D(const FormatArgs &args, void *result) {
988 // trans nchw to NHWC or HWCN
989 MS_LOG(DEBUG) << "Trans format from nchw to " << args.device_format;
990 MS_EXCEPTION_IF_NULL(result);
991 int64_t size = 0;
992 if (!CheckArgs(args, &size)) {
993 MS_LOG(ERROR) << "Check args failed.";
994 return false;
995 }
996 auto n = args.host_shape[kN];
997 auto c = args.host_shape[kC];
998 auto h = args.host_shape[kH];
999 auto w = args.host_shape[kW];
1000 for (int64_t ni = 0; ni < n; ni++) {
1001 for (int64_t ci = 0; ci < c; ci++) {
1002 for (int64_t hi = 0; hi < h; hi++) {
1003 for (int64_t wi = 0; wi < w; wi++) {
1004 auto src_idx = ni * c * h * w + ci * h * w + hi * w + wi;
1005 int64_t dst_idx = 0;
1006 if (args.device_format == kOpFormat_NHWC) {
1007 dst_idx = ni * h * w * c + hi * w * c + wi * c + ci;
1008 } else if (args.device_format == kOpFormat_HWCN) {
1009 dst_idx = hi * w * c * n + wi * c * n + ci * n + ni;
1010 }
1011 SetData(size, false, src_idx, dst_idx, args, result);
1012 }
1013 }
1014 }
1015 }
1016 return true;
1017 }
1018
TO_NCHW(const FormatArgs & args,void * result)1019 bool FormatTransfer::TO_NCHW(const FormatArgs &args, void *result) {
1020 MS_LOG(DEBUG) << "Trans format to nchw from " << args.device_format;
1021 MS_EXCEPTION_IF_NULL(result);
1022 int64_t size = 0;
1023 if (!CheckArgs(args, &size)) {
1024 MS_LOG(ERROR) << "Check args failed.";
1025 return false;
1026 }
1027 auto n = args.host_shape[kN];
1028 auto c = args.host_shape[kC];
1029 auto h = args.host_shape[kH];
1030 auto w = args.host_shape[kW];
1031 for (int64_t ni = 0; ni < n; ni++) {
1032 for (int64_t ci = 0; ci < c; ci++) {
1033 for (int64_t hi = 0; hi < h; hi++) {
1034 for (int64_t wi = 0; wi < w; wi++) {
1035 auto dst_idx = ni * c * h * w + ci * h * w + hi * w + wi;
1036 int64_t src_idx = 0;
1037 if (args.device_format == kOpFormat_NHWC) {
1038 src_idx = ni * h * w * c + hi * w * c + wi * c + ci;
1039 } else if (args.device_format == kOpFormat_HWCN) {
1040 src_idx = hi * w * c * n + wi * c * n + ci * n + ni;
1041 }
1042 SetData(size, false, src_idx, dst_idx, args, result);
1043 }
1044 }
1045 }
1046 }
1047 return true;
1048 }
1049
// Transform NCHW host data into FRACTAL_Z device layout: a grid of
// (C1*H*W) x ceil(N/16) fractal matrices of C0 x 16 elements each.
// Positions past the real N or C extent are zero padded.
bool FormatTransfer::NCHW_TO_FRAC_Z(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from nchw to frac_z";
  MS_EXCEPTION_IF_NULL(result);
  auto size = Common4DCheck(args);
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  auto c0 = GetCubeSizeByType(args.src_data_type);
  auto c1 = DivCeil(c, c0);
  auto hw = h * w;
  auto chw = c * hw;
  auto hwc0 = hw * c0;
  auto nchw = n * chw;

  auto hf_cnt = DivCeil(n, kNiSize);    // fractal count along the N axis
  auto vf_cnt = c1 * hw;                // fractal count along the C1*H*W axis
  auto fractal_ele_cnt = c0 * kNiSize;  // elements per fractal matrix
  auto total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
  auto dst_size = total_ele_cnt * size;
  if (dst_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, "
                  << "dst size is :" << dst_size << ", device size is :" << args.device_size;
    return false;
  }

  for (int64_t vfi = 0; vfi < vf_cnt; vfi++) {
    auto vf_base_i = vfi * hf_cnt;  // vertical fractal matrix base index
    for (int64_t hfi = 0; hfi < hf_cnt; hfi++) {
      auto gfi = vf_base_i + hfi;  // global fractal matrix index
      auto src_n_offset = hfi * chw * kNiSize;
      // vfi decomposes as (c1 block = vfi / hw) and (spatial pos = vfi % hw).
      auto src_f_offset = src_n_offset + vfi % hw + vfi / hw * hwc0;
      for (int64_t row = 0; row < c0; row++) {
        auto src_ci = vfi / hw * c0 + row;  // absolute channel index
        auto src_row_offset = src_f_offset + row * hw;
        for (int64_t col = 0; col < kNiSize; col++) {
          auto src_ni = hfi * kNiSize + col;  // absolute batch index
          auto src_idx = src_row_offset + chw * col;
          auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row;
          // Zero-fill where the fractal grid extends past the real tensor.
          auto pad_zero = src_ni >= n || src_idx >= nchw || src_ci >= c;
          SetData(size, pad_zero, src_idx, dst_idx, args, result);
        }
      }
    }
  }
  return true;
}
1097
// Transform host data into FRACTAL_NZ device layout. The host shape is first
// collapsed to {times, H, W}; the device shape's tail is read as
// [..., W1, H1, H0, W0] fractal tile dims.
bool FormatTransfer::NCHW_TO_FRAC_NZ(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from nchw to frac_nz.";
  MS_EXCEPTION_IF_NULL(result);
  ShapeVector hw_shape;
  if (!TransShapeToHW_NZ(args.host_shape, &hw_shape)) {
    MS_LOG(ERROR) << "Trans shape failed..";
    return false;
  }
  if (hw_shape.size() < kDim3 || args.device_shape.size() < kDim4) {
    MS_LOG(ERROR) << "Invalid shape size.";
    return false;
  }
  auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
    return false;
  }

  auto dst_size = abstract::ShapeSize(args.device_shape) * size;
  if (dst_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
    return false;
  }
  auto times = hw_shape.at(0);  // product of all leading (batch) dims
  auto h = hw_shape.at(hw_h);
  auto w = hw_shape.at(hw_w);
  auto hw = h * w;

  // Fractal tile dims taken from the tail of the device shape.
  auto shape_size = args.device_shape.size();
  auto w1 = args.device_shape[shape_size - fnz_w1];
  auto h1 = args.device_shape[shape_size - fnz_h1];
  auto h0 = args.device_shape[shape_size - fnz_h0];
  auto w0 = args.device_shape[shape_size - fnz_w0];
  auto h1h0w0 = h1 * h0 * w0;
  auto w1h1h0w0 = w1 * h1h0w0;
  auto num_w1 = w / w0;  // number of complete W0-wide column tiles

  for (int64_t times_idx = 0; times_idx < times; times_idx++) {
    auto times_head = times_idx * w1h1h0w0;
    auto src_times_head = times_idx * hw;
    for (int64_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
      auto h1h0_head = times_head + h1h0_idx * w0;
      auto src_h_head = src_times_head + h1h0_idx * w;
      // Copy the complete W0-wide tiles of this row.
      for (int64_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
        for (int64_t i = 0; i < w0; ++i) {
          int64_t src_idx = src_h_head + w1_idx * w0 + i;
          int64_t dst_idx = h1h0_head + w1_idx * h1h0w0 + i;
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
      // Remaining tail columns when w is not a multiple of w0.
      auto w1_head = num_w1 * w0;
      for (int64_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
        auto src_w_idx = w1_head + w0_idx;
        int64_t dst_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
        int64_t src_idx = src_h_head + src_w_idx;
        SetData(size, false, src_idx, dst_idx, args, result);
      }
    }
  }
  return true;
}
1159
// Transform NCHW host data to FRACTAL_Z_C04 (C0 fixed to 4). The destination
// is written sequentially as 16x16 cubes; each destination slot is mapped
// back to its (n, c1, h, w, c0) source coordinates, with zero padding where
// the cube grid overruns the real tensor.
bool FormatTransfer::NCHW_TO_FRAC_ZC04(const FormatArgs &args, void *result) {
  // trans nchw to FracZc04
  MS_LOG(DEBUG) << "Trans format from nchw to FracZc04.";
  MS_EXCEPTION_IF_NULL(result);
  int64_t size = 0;
  if (!CheckArgs(args, &size)) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
  auto cube = GetCubeSizeByType(args.src_data_type);
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  const int64_t c0 = 4;  // C04 layout: fixed channel tile of 4
  auto c1 = DivCeil(c, c0);
  auto hwc0 = h * w * c0;
  auto hwc = h * w * c;
  auto nhwc = n * h * w * c;
  auto n_cnt = DivCeil(n, kNiSize);                // 16-cubes along N
  auto v_cnt = DivCeil(h * w * c0 * c1, cube);     // 16-cubes along C1*H*W*C0
  int64_t dst_idx = 0;  // destination is filled strictly sequentially

  for (int64_t vi = 0; vi < v_cnt; vi++) {
    for (int64_t ni = 0; ni < n_cnt; ni++) {
      for (int64_t col = 0; col < kNiSize; col++) {
        for (int64_t row = 0; row < kNiSize; row++) {
          int64_t cur_cube_n = kNiSize * ni + col;
          int64_t cur_cube_c1hwc0 = kNiSize * vi + row;
          auto desc_g = cur_cube_n / n;
          auto desc_n = cur_cube_n % n;
          // Decompose the linear C1*H*W*C0 position into (c1, h, w, c0).
          auto desc_c1 = cur_cube_c1hwc0 / hwc0;
          auto desc_c0 = cur_cube_c1hwc0 % c0;
          auto desc_h = (cur_cube_c1hwc0 - hwc0 * desc_c1) / (w * c0);
          auto desc_w = (cur_cube_c1hwc0 - hwc0 * desc_c1 - w * c0 * desc_h) / c0;
          auto c_idx = desc_c1 * c0 + desc_c0;
          auto src_idx = desc_g * nhwc + desc_n * hwc + c_idx * h * w + desc_h * w + desc_w;
          // Zero-fill destination slots past the real N/C extent.
          auto pad_zero = desc_g >= 1 || desc_n >= n || c_idx >= c;
          SetData(size, pad_zero, src_idx, dst_idx, args, result);
          dst_idx++;
        }
      }
    }
  }
  return true;
}
1206
// Transform NCHW host data into NC1HWC0: the C axis is split into C1 blocks
// of C0 channels (C0 from the cube size, or 4 for the NC1HWC0_C04 format);
// channel positions past the real C are zero padded.
bool FormatTransfer::NCHW_TO_NC1HWC0(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from nchw to Nc1h1wc0";
  MS_EXCEPTION_IF_NULL(result);
  auto size = Common4DCheck(args);
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }

  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  auto c0 = GetCubeSizeByType(args.src_data_type);
  if (args.device_format == kOpFormat_NC1HWC0_C04) {
    c0 = kCubeSize_C04;  // C04 variant uses a fixed channel tile of 4
  }
  auto c1 = DivCeil(c, c0);
  auto hw = h * w;
  auto chw = c * hw;
  auto c1hwc0 = c1 * hw * c0;
  auto wc0 = w * c0;

  // Walk the device layout in order, pulling each element from NCHW.
  for (int64_t n_idx = 0; n_idx < n; n_idx++) {
    int64_t n_head_addr = n_idx * c1hwc0;
    for (int64_t c1_idx = 0; c1_idx < c1; c1_idx++) {
      int64_t c1_head_addr = n_head_addr + c1_idx * hw * c0;
      for (int64_t h_idx = 0; h_idx < h; h_idx++) {
        int64_t h_head_addr = c1_head_addr + h_idx * wc0;
        for (int64_t w_idx = 0; w_idx < w; w_idx++) {
          int64_t w_head_addr = h_head_addr + w_idx * c0;
          for (int64_t c0_idx = 0; c0_idx < c0; c0_idx++) {
            int64_t dst_idx = c0_idx + w_head_addr;
            int64_t c_idx = c0_idx + c1_idx * c0;  // absolute channel index
            int64_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx;
            auto pad_zero = c_idx >= c;  // last C1 block may be partial
            SetData(size, pad_zero, src_idx, dst_idx, args, result);
          }
        }
      }
    }
  }
  return true;
}
1252
// NC1HWC04 is NC1HWC0 with C0 fixed to 4; the shared implementation selects
// C0 = 4 when device_format is NC1HWC0_C04.
bool FormatTransfer::NCHW_TO_NC1HWC04(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from nchw to Nc1hwc04.";
  return NCHW_TO_NC1HWC0(args, result);
}
1257
// Transform NCHW host data into C1HWNCoC0 device layout. Co and C0 are read
// from the device shape; only diagonal positions (c0_i == co_i) of a valid
// channel carry data, every other slot is zero padded.
bool FormatTransfer::NCHW_TO_C1HWNCOC0(const FormatArgs &args, void *result) {
  // trans nchw to c1hwncoc0
  MS_LOG(DEBUG) << "Trans format from nchw to c1hwncoc0.";
  MS_EXCEPTION_IF_NULL(result);
  int64_t size = 0;
  if (!CheckArgs(args, &size)) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  const int co_idx = 4;  // Co position in the 6-D device shape
  const int c0_idx = 5;  // C0 position in the 6-D device shape
  auto c1 = args.device_shape[0];
  auto co = args.device_shape[co_idx];
  auto c0 = args.device_shape[c0_idx];

  for (int64_t c1_i = 0; c1_i < c1; c1_i++) {
    for (int64_t h_i = 0; h_i < h; h_i++) {
      for (int64_t w_i = 0; w_i < w; w_i++) {
        for (int64_t n_i = 0; n_i < n; n_i++) {
          for (int64_t co_i = 0; co_i < co; co_i++) {
            for (int64_t c0_i = 0; c0_i < c0; c0_i++) {
              int64_t dst_idx = c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 +
                                co_i * c0 + c0_i;
              int64_t c_i = c0_i + c1_i * c0;  // absolute channel index
              int64_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
              // Only diagonal (c0 == co) slots of real channels hold data.
              auto pad_zero = !(c_i < c && c0_i == co_i);
              SetData(size, pad_zero, src_idx, dst_idx, args, result);
            }
          }
        }
      }
    }
  }
  return true;
}
1297
// Transform 5-D NCDHW host data into NDC1HWC0: the C axis is split into C1
// blocks of C0 channels and moved behind D; positions past the real C are
// zero padded.
bool FormatTransfer::NCDHW_TO_NDC1HWC0(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans from ncdhw to ndc1hwc0";
  MS_EXCEPTION_IF_NULL(result);

  if (args.host_shape.size() != kDim5) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }
  auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
    return false;
  }
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }

  auto n = args.host_shape[N_ncdhw];
  auto c = args.host_shape[C_ncdhw];
  auto d = args.host_shape[D_ncdhw];
  auto h = args.host_shape[H_ncdhw];
  auto w = args.host_shape[W_ncdhw];
  auto c0 = GetCubeSizeByType(args.src_data_type);
  auto c1 = DivCeil(c, c0);
  // Precomputed strides for the source (NCDHW) and destination (NDC1HWC0).
  const int64_t cdhw = c * d * h * w;
  const int64_t dhw = d * h * w;
  const int64_t hw = h * w;
  const int64_t dc1hwc0 = d * c1 * h * w * c0;
  const int64_t c1hwc0 = c1 * h * w * c0;
  const int64_t hwc0 = h * w * c0;
  const int64_t wc0 = w * c0;

  for (int64_t n_i = 0; n_i < n; n_i++) {
    int64_t n_head = n_i * dc1hwc0;
    for (int64_t d_i = 0; d_i < d; d_i++) {
      int64_t d_head = n_head + d_i * c1hwc0;
      for (int64_t c1_i = 0; c1_i < c1; c1_i++) {
        int64_t c1_head = d_head + c1_i * hwc0;
        for (int64_t h_i = 0; h_i < h; h_i++) {
          int64_t h_head = c1_head + h_i * wc0;
          for (int64_t w_i = 0; w_i < w; w_i++) {
            int64_t w_head = h_head + w_i * c0;
            for (int64_t c0_i = 0; c0_i < c0; c0_i++) {
              int64_t dst_i = c0_i + w_head;
              int64_t c_i = c0_i + c1_i * c0;  // absolute channel index
              int64_t src_i = n_i * cdhw + c_i * dhw + d_i * hw + h_i * w + w_i;
              auto pad_zero = c_i >= c;  // last C1 block may be partial
              SetData(size, pad_zero, src_i, dst_i, args, result);
            }
          }
        }
      }
    }
  }
  return true;
}
1356
// Transform 5-D NCDHW host data into FRACTAL_Z_3D device layout
// [D, C1, H, W, N1*N0, C0]; N is padded up to a multiple of 16 and C up to a
// multiple of C0, with zero fill in the padded positions.
bool FormatTransfer::NCDHW_TO_FRAC_Z3D(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans from ncdhw to frac_z_3d";
  MS_EXCEPTION_IF_NULL(result);

  if (args.host_shape.size() != kDim5) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }
  auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
    return false;
  }
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }

  auto n = args.host_shape[N_ncdhw];
  auto c = args.host_shape[C_ncdhw];
  auto d = args.host_shape[D_ncdhw];
  auto h = args.host_shape[H_ncdhw];
  auto w = args.host_shape[W_ncdhw];

  auto n1n0 = DivCeil(n, kNiSize) * kNiSize;  // N padded to a multiple of 16
  auto c0 = GetCubeSizeByType(args.src_data_type);
  auto c1 = DivCeil(c, c0);
  // Destination strides for the [D, C1, H, W, N1N0, C0] layout.
  auto hw = h * w;
  auto dhw = d * hw;
  auto cdhw = c * dhw;
  auto n1n0c0 = n1n0 * c0;
  auto wn1n0c0 = w * n1n0c0;
  auto hwn1n0c0 = h * wn1n0c0;
  auto c1hwn1n0c0 = c1 * hwn1n0c0;

  for (int64_t d_i = 0; d_i < d; d_i++) {
    for (int64_t c1_i = 0; c1_i < c1; c1_i++) {
      for (int64_t h_i = 0; h_i < h; h_i++) {
        for (int64_t w_i = 0; w_i < w; w_i++) {
          for (int64_t n1n0_i = 0; n1n0_i < n1n0; n1n0_i++) {
            for (int64_t c0_i = 0; c0_i < c0; c0_i++) {
              auto dst_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i;
              // ncdhw
              int64_t src_i = n1n0_i * cdhw + (c1_i * c0 + c0_i) * dhw + d_i * hw + h_i * w + w_i;
              // Zero-fill the padded C and N positions.
              auto pad_zero = ((c1_i * c0 + c0_i) >= c) || (n1n0_i >= n);
              SetData(size, pad_zero, src_i, dst_i, args, result);
            }
          }
        }
      }
    }
  }
  return true;
}
1412
// Bidirectional copy between NCHW host data and grouped FRACTAL_Z device
// data: to_device=true scatters host->device, to_device=false gathers
// device->host. e_mult groups are fused per fractal so both Cin and Cout
// align to cube tiles; the destination is zeroed first, so unwritten
// (padding) positions stay zero.
bool FormatTransfer::NCHW_TO_FRAC_Z_WITH_GROUPS(const FormatArgs &args, void *result, bool to_device, int64_t groups) {
  MS_EXCEPTION_IF_NULL(result);
  auto size = Common4DCheck(args);
  auto n_dim = args.host_shape[kN];
  auto c_dim = args.host_shape[kC];
  auto h_dim = args.host_shape[kH];
  auto w_dim = args.host_shape[kW];
  auto d_dim = 1;  // no depth axis in the 4-D case
  auto cin_ori = c_dim;
  if (groups <= 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "The value of groups should be greater than 0, but got " << groups;
  }
  // cppcheck-suppress *
  auto cout_ori = n_dim / groups;
  if (cin_ori == 0 || cout_ori == 0) {
    MS_LOG(ERROR) << "cin_ori, cout_ori must not equal to 0";
    return false;
  }
  auto cube_k = GetCubeSizeByType(args.src_data_type);
  // e_mult: number of groups fused so Cin aligns to cube_k and Cout to 16.
  auto e_mult = std::min(Lcm(Lcm(cin_ori, cube_k) / cin_ori, Lcm(cout_ori, kCubeSize) / cout_ori), groups);
  if (e_mult == 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "The value of e_mult should be greater than 0, but got " << e_mult;
  }
  auto cin_opt = DivCeil(e_mult * cin_ori, cube_k) * cube_k;      // padded Cin
  auto cout_opt = DivCeil(e_mult * cout_ori, kCubeSize) * kCubeSize;  // padded Cout
  // cppcheck-suppress *
  auto c1_dim = cin_opt / cube_k;
  auto dst_size =
    to_device ? abstract::ShapeSize(args.device_shape) * size : abstract::ShapeSize(args.host_shape) * size;
  if (dst_size == 0) {
    return true;
  }
  // Pre-zero the destination: only real (g, c, n) positions are written below.
  auto ret = memset_s(result, LongToSize(dst_size), 0, LongToSize(dst_size));
  if (ret != EOK) {
    MS_LOG(ERROR) << "memset failed";
    return false;
  }
  for (int64_t g = 0; g < groups; ++g) {
    for (int64_t d = 0; d < d_dim; ++d) {
      for (int64_t c = 0; c < c_dim; ++c) {
        for (int64_t h = 0; h < h_dim; ++h) {
          for (int64_t w = 0; w < w_dim; ++w) {
            for (int64_t n = 0; n < cout_ori; ++n) {
              int64_t e_val = g % e_mult;  // position of this group inside the fused block
              int64_t dst_ci = e_val * cin_ori + c;
              int64_t dst_co = e_val * cout_ori + n;
              int64_t src_co = g * cout_ori + n;
              int64_t temporary = dst_ci % cube_k;  // C0 offset inside the cube
              int64_t dev_idx = (g / e_mult) * d_dim * c1_dim * h_dim * w_dim * cout_opt * cube_k +
                                d * c1_dim * h_dim * w_dim * cout_opt * cube_k +
                                (dst_ci / cube_k) * h_dim * w_dim * cout_opt * cube_k + h * w_dim * cout_opt * cube_k +
                                w * cout_opt * cube_k + dst_co * cube_k + temporary;
              int64_t hst_idx =
                src_co * c_dim * d_dim * h_dim * w_dim + c * d_dim * h_dim * w_dim + d * h_dim * w_dim + h * w_dim + w;
              if (to_device) {
                SetData(size, false, hst_idx, dev_idx, args, result);
              } else {
                SetData(size, false, dev_idx, hst_idx, args, result);
              }
            }
          }
        }
      }
    }
  }
  return true;
}
1480
NC1HWC0_TO_NCHW(const FormatArgs & args,void * result)1481 bool FormatTransfer::NC1HWC0_TO_NCHW(const FormatArgs &args, void *result) {
1482 MS_LOG(DEBUG) << "Trans format from nc1h1wc0 to nchw";
1483 MS_EXCEPTION_IF_NULL(result);
1484 auto size = Common4DCheck(args);
1485 auto total_size = abstract::ShapeSize(args.device_shape) * size;
1486 if (total_size != SizeToLong(args.device_size)) {
1487 MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1488 return false;
1489 }
1490
1491 auto n = args.host_shape[kN];
1492 auto c = args.host_shape[kC];
1493 auto h = args.host_shape[kH];
1494 auto w = args.host_shape[kW];
1495 auto c1 = args.device_shape[kDim1];
1496 auto c0 = args.device_shape[kDim4];
1497
1498 auto hw = h * w;
1499 auto chw = c * hw;
1500 auto wc0 = w * c0;
1501 auto hwc0 = h * wc0;
1502 auto c1hwc0 = c1 * hwc0;
1503
1504 for (int64_t n_idx = 0; n_idx < n; n_idx++) {
1505 int64_t n_head_addr = n_idx * chw;
1506 for (int64_t c_idx = 0; c_idx < c; c_idx++) {
1507 int64_t c_head_addr = n_head_addr + c_idx * hw;
1508 for (int64_t h_idx = 0; h_idx < h; h_idx++) {
1509 int64_t h_head_addr = c_head_addr + h_idx * w;
1510 for (int64_t w_idx = 0; w_idx < w; w_idx++) {
1511 int64_t dst_idx = h_head_addr + w_idx;
1512 int64_t c1_idx = c_idx / c0;
1513 int64_t c0_idx = c_idx % c0;
1514 int64_t src_idx = n_idx * c1hwc0 + c1_idx * hwc0 + h_idx * wc0 + w_idx * c0 + c0_idx;
1515 SetData(size, false, src_idx, dst_idx, args, result);
1516 }
1517 }
1518 }
1519 }
1520 return true;
1521 }
1522
// NC1HWC04 -> NCHW delegates to the shared NC1HWC0 path, which reads C1 and
// C0 from the device shape.
bool FormatTransfer::NC1HWC04_TO_NCHW(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from Nc1hwc04 to nchw.";
  return NC1HWC0_TO_NCHW(args, result);
}
1527
// Gather C1HWNCoC0 device data back into dense NCHW host layout; only the
// diagonal positions (co_i == c0_i) written by NCHW_TO_C1HWNCOC0 are read.
bool FormatTransfer::C1HWNCOC0_TO_NCHW(const FormatArgs &args, void *result) {
  // trans c1hwncoc0 to nchw
  MS_LOG(DEBUG) << "Trans format from c1hwncoc0 to nchw";
  MS_EXCEPTION_IF_NULL(result);
  int64_t size = 0;
  if (!CheckArgs(args, &size)) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  const int co_idx = 4;  // Co position in the 6-D device shape
  const int c0_idx = 5;  // C0 position in the 6-D device shape
  auto co = args.device_shape[co_idx];
  auto c0 = args.device_shape[c0_idx];
  auto cube_k = GetCubeSizeByType(args.src_data_type);
  for (int64_t n_i = 0; n_i < n; n_i++) {
    for (int64_t c_i = 0; c_i < c; c_i++) {
      for (int64_t h_i = 0; h_i < h; h_i++) {
        for (int64_t w_i = 0; w_i < w; w_i++) {
          int64_t dst_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
          // Split the channel into (c1, c0); data sits on the co == c0 diagonal.
          int64_t c1_i = c_i / cube_k;
          int64_t c0_i = c_i % cube_k;
          int64_t co_i = c0_i;
          int64_t src_idx =
            c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 + co_i * c0 + c0_i;
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
    }
  }
  return true;
}
1563
// Gather FRACTAL_Z device data back into dense NCHW host layout. The device
// shape's N0/Ni/C0 dims are read to reconstruct the fractal strides.
bool FormatTransfer::FRAC_Z_TO_NCHW(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from frac_z to nchw";
  MS_EXCEPTION_IF_NULL(result);
  auto size = Common4DCheck(args);
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }

  auto n0 = args.device_shape.at(fz_n0);
  auto ni = args.device_shape.at(fz_ni);
  auto c0 = args.device_shape.at(fz_c0);
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  // Source (device) strides: [C1*H*W, N0, Ni, C0] with nc = Ni * N0.
  auto nc = ni * n0;
  auto ncc0 = nc * c0;
  auto wncc0 = w * ncc0;
  auto hwncc0 = h * wncc0;
  auto hw = h * w;
  auto chw = c * hw;

  for (int64_t n_idx = 0; n_idx < n; n_idx++) {
    int64_t n_head_addr = n_idx * chw;
    for (int64_t c_idx = 0; c_idx < c; c_idx++) {
      int64_t c_head_addr = n_head_addr + c_idx * hw;
      for (int64_t h_idx = 0; h_idx < h; h_idx++) {
        int64_t h_head_addr = c_head_addr + h_idx * w;
        for (int64_t w_idx = 0; w_idx < w; w_idx++) {
          auto dst_idx = h_head_addr + w_idx;
          // Split the channel into (c1, c0) to locate the source fractal.
          auto c1_idx = c_idx / c0;
          auto c0_idx = c_idx % c0;
          auto nc_idx = n_idx;
          auto src_idx = c1_idx * hwncc0 + h_idx * wncc0 + w_idx * ncc0 + nc_idx * c0 + c0_idx;
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
    }
  }
  return true;
}
1607
// Converts FRACTAL_NZ device data back to the host layout.
// The host shape is first collapsed to (times, H, W); the device shape's last
// four dims are (W1, H1, H0, W0), i.e. the matrix tiled into W0 x H0 cubes.
bool FormatTransfer::FRAC_NZ_TO_NCHW(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from frac_nz to nchw";
  MS_EXCEPTION_IF_NULL(result);
  ShapeVector hw_shape;
  if (!TransShapeToHW_NZ(args.host_shape, &hw_shape)) {
    MS_LOG(ERROR) << "Trans shape failed..";
    return false;
  }
  if (hw_shape.size() < kDim3 || args.device_shape.size() < kDim4) {
    MS_LOG(ERROR) << "Invalid shape size.";
    return false;
  }
  auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
    return false;
  }

  auto dst_size = abstract::ShapeSize(args.device_shape) * size;
  if (dst_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
    return false;
  }
  auto times = hw_shape.at(0);  // product of all leading (batch) dims
  auto h = hw_shape.at(hw_h);
  auto w = hw_shape.at(hw_w);
  auto hw = h * w;

  // Trailing device dims, addressed from the end of the shape.
  auto shape_size = args.device_shape.size();
  auto w1 = args.device_shape[shape_size - fnz_w1];
  auto h1 = args.device_shape[shape_size - fnz_h1];
  auto h0 = args.device_shape[shape_size - fnz_h0];
  auto w0 = args.device_shape[shape_size - fnz_w0];
  auto h1h0w0 = h1 * h0 * w0;
  auto w1h1h0w0 = w1 * h1h0w0;
  auto num_w1 = w / w0;  // number of full W0-wide column blocks

  for (int64_t times_idx = 0; times_idx < times; times_idx++) {
    auto times_head = times_idx * w1h1h0w0;
    auto src_times_head = times_idx * hw;
    for (int64_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
      auto h1h0_head = times_head + h1h0_idx * w0;
      auto src_h_head = src_times_head + h1h0_idx * w;
      // Copy the full W0-wide blocks of this row.
      for (int64_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
        for (int64_t i = 0; i < w0; ++i) {
          int64_t src_idx = h1h0_head + w1_idx * h1h0w0 + i;
          int64_t dst_idx = src_h_head + w1_idx * w0 + i;
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
      // Copy the tail columns when W is not a multiple of W0.
      auto w1_head = num_w1 * w0;
      for (int64_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
        auto src_w_idx = w1_head + w0_idx;
        int64_t src_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
        int64_t dst_idx = src_h_head + src_w_idx;
        SetData(size, false, src_idx, dst_idx, args, result);
      }
    }
  }
  return true;
}
1669
// Converts FRACTAL_Z_3D device data back to host NCDHW.
// Device layout (per the src index below): D, C1, H, W, N1N0, C0.
bool FormatTransfer::FRAC_Z3D_TO_NCDHW(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans from frac_z_3d to ncdhw";
  MS_EXCEPTION_IF_NULL(result);

  if (args.host_shape.size() != kDim5) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }
  auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
    return false;
  }
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  auto n = args.host_shape[N_ncdhw];
  auto c = args.host_shape[C_ncdhw];
  auto d = args.host_shape[D_ncdhw];
  auto h = args.host_shape[H_ncdhw];
  auto w = args.host_shape[W_ncdhw];
  const int kFZ3D_C0 = 3;  // C0 axis position in the device shape
  auto c0 = args.device_shape[kFZ3D_C0];
  auto cube_k = GetCubeSizeByType(args.src_data_type);
  auto c1 = DivCeil(c, cube_k);
  auto n1n0 = DivCeil(n, kNiSize) * kNiSize;  // N padded up to a multiple of Ni
  // Device-side strides.
  auto n1n0c0 = n1n0 * c0;
  auto wn1n0c0 = w * n1n0c0;
  auto hwn1n0c0 = h * wn1n0c0;
  auto c1hwn1n0c0 = c1 * hwn1n0c0;
  // Host-side strides for dense NCDHW.
  auto hw = h * w;
  auto dhw = d * hw;
  auto cdhw = c * dhw;

  for (int64_t n_i = 0; n_i < n; n_i++) {
    int64_t n_head = n_i * cdhw;
    for (int64_t c_i = 0; c_i < c; c_i++) {
      int64_t c_head = n_head + c_i * dhw;
      for (int64_t d_i = 0; d_i < d; d_i++) {
        int64_t d_head = c_head + d_i * hw;
        for (int64_t h_i = 0; h_i < h; h_i++) {
          int64_t h_head = d_head + h_i * w;
          for (int64_t w_i = 0; w_i < w; w_i++) {
            int64_t dst_i = h_head + w_i;
            // NOTE(review): c1 is derived from cube_k while the C1/C0 split
            // below uses the device c0 — these agree only when c0 == cube_k;
            // confirm for dtypes with a non-16 cube size.
            int64_t c1_i = c_i / c0;
            int64_t c0_i = c_i % c0;
            int64_t nc_i = n_i;
            int64_t src_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + nc_i * c0 + c0_i;
            SetData(size, false, src_i, dst_i, args, result);
          }
        }
      }
    }
  }
  return true;
}
1728
NDC1HWC0_TO_NCDHW(const FormatArgs & args,void * result)1729 bool FormatTransfer::NDC1HWC0_TO_NCDHW(const FormatArgs &args, void *result) {
1730 MS_LOG(DEBUG) << "Trans from ndc1hwc0 to ncdhw";
1731 MS_EXCEPTION_IF_NULL(result);
1732
1733 if (args.host_shape.size() != kDim5) {
1734 MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
1735 return false;
1736 }
1737 auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
1738 if (size < 1) {
1739 MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
1740 return false;
1741 }
1742 auto total_size = abstract::ShapeSize(args.device_shape) * size;
1743 if (total_size != SizeToLong(args.device_size)) {
1744 MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1745 return false;
1746 }
1747 auto n = args.host_shape[N_ncdhw];
1748 auto c = args.host_shape[C_ncdhw];
1749 auto d = args.host_shape[D_ncdhw];
1750 auto h = args.host_shape[H_ncdhw];
1751 auto w = args.host_shape[W_ncdhw];
1752 auto c1 = args.device_shape[C1_ndc1hwc0];
1753 auto c0 = args.device_shape[C0_ndc1hwc0];
1754 const int64_t cdhw = c * d * h * w;
1755 const int64_t dhw = d * h * w;
1756 const int64_t hw = h * w;
1757 const int64_t dc1hwc0 = d * c1 * h * w * c0;
1758 const int64_t c1hwc0 = c1 * h * w * c0;
1759 const int64_t hwc0 = h * w * c0;
1760 const int64_t wc0 = w * c0;
1761
1762 for (int64_t n_i = 0; n_i < n; n_i++) {
1763 int64_t n_head = n_i * cdhw;
1764 for (int64_t c_i = 0; c_i < c; c_i++) {
1765 int64_t c_head = n_head + c_i * dhw;
1766 for (int64_t d_i = 0; d_i < d; d_i++) {
1767 int64_t d_head = c_head + d_i * hw;
1768 for (int64_t h_i = 0; h_i < h; h_i++) {
1769 int64_t h_head = d_head + h_i * w;
1770 for (int64_t w_i = 0; w_i < w; w_i++) {
1771 int64_t dst_i = h_head + w_i;
1772 int64_t c1_i = c_i / c0;
1773 int64_t c0_i = c_i % c0;
1774 auto src_idx = n_i * dc1hwc0 + d_i * c1hwc0 + c1_i * hwc0 + h_i * wc0 + w_i * c0 + c0_i;
1775 SetData(size, false, src_idx, dst_i, args, result);
1776 }
1777 }
1778 }
1779 }
1780 }
1781 return true;
1782 }
1783
// Converts grouped FRACTAL_Z device data back to host NCHW.
// Delegates to NCHW_TO_FRAC_Z_WITH_GROUPS with to_device=false, which
// implements both directions of the grouped transfer.
bool FormatTransfer::FRAC_Z_TO_NCHW_WITH_GROUPS(const FormatArgs &args, void *result, int64_t groups) {
  MS_LOG(DEBUG) << "Trans format from frac_z to nchw with groups=" << groups;
  return NCHW_TO_FRAC_Z_WITH_GROUPS(args, result, false, groups);
}
1788
Common4DCheck(const FormatArgs & args)1789 int64_t FormatTransfer::Common4DCheck(const FormatArgs &args) {
1790 if (args.host_shape.size() != kDim4) {
1791 MS_LOG(INTERNAL_EXCEPTION) << "Invalid host shape, host shape dims:" << args.host_shape.size()
1792 << ", expect dims:" << kNchwDims;
1793 }
1794 auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
1795 if (size < 1) {
1796 MS_LOG(INTERNAL_EXCEPTION) << "Illegal dtype: " << args.src_data_type;
1797 }
1798 return size;
1799 }
1800
1801 // ######################## RANGE TRANS ########################
// Maps a host-shape value range to the range of the given device format.
// Formats that keep the host layout return the range unchanged; padded 4D/5D
// formats first pad the range, then apply the per-format range transfer.
RangePair ShapeRangeTransfer::GetRealRange(const RangePair &ori_range, const std::string &format, const TypeId &type,
                                           const std::string &padding_str) const {
  // These tables are immutable; build them once instead of on every call.
  static const std::set<std::string> no_need_change = {kOpFormat_ND, kOpFormat_DEFAULT, kOpFormat_NCHW,
                                                       kOpFormat_NCDHW};
  using RangeTransfer = std::function<RangePair(const RangePair &, const TypeId &)>;
  static const std::map<std::string, RangeTransfer> format_range_map = {{kOpFormat_NHWC, NHWCRange},
                                                                        {kOpFormat_HWCN, HWCNRange},
                                                                        {kOpFormat_FRAC_Z, FRAC_ZRange},
                                                                        {kOpFormat_NC1HWC0, NC1HWC0Range},
                                                                        {kOpFormat_NDC1HWC0, NDC1HWC0Range},
                                                                        {kOpFormat_C1HWNCoC0, C1HWNCOC0Range},
                                                                        {kOpFormat_NC1HWC0_C04, NC1HWC04Range},
                                                                        {kOpFormat_FRACTAL_Z_3D, FRAC_Z_3DRange},
                                                                        {kOpFormat_FRACTAL_Z_C04, FRAC_ZC04Range}};
  if (no_need_change.find(format) != no_need_change.end()) {
    return ori_range;
  }
  // kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_FRAC_NZ no need pad range
  if (format == kOpFormat_FRACTAL_ZN_LSTM) {
    return FRAC_ZN_LSTMRange(ori_range, type);
  }
  if (format == kOpFormat_FRAC_NZ) {
    return FRAC_NZRange(ori_range, type);
  }
  auto temp_range = ori_range;
  if (ori_range.size() < kDim4 && !IsOneOf3DFormat(format)) {
    MS_LOG(DEBUG) << "A special format:" << format << " with a range size less than 4, so padding the range firstly";
    temp_range = PaddingRangeTo4D(ori_range, padding_str);
  }
  if (ori_range.size() < kDim5 && IsOneOf3DFormat(format)) {
    MS_LOG(DEBUG) << "A special format:" << format << " with a range size less than 5, so padding the range firstly";
    temp_range = PaddingRangeTo5D(ori_range, padding_str);
  }
  auto iter = format_range_map.find(format);
  if (iter == format_range_map.end()) {
    MS_LOG(INFO) << "Can not find a supported format: " << format << ", using default range";
    return ori_range;
  }
  return iter->second(temp_range, type);
}
1841
// Permutes an NCHW range into NHWC order.
RangePair ShapeRangeTransfer::NHWCRange(const RangePair &ori_range, const TypeId &) {
  return {ori_range[kN], ori_range[kH], ori_range[kW], ori_range[kC]};
}
1850
// Permutes an NCHW range into HWCN order.
RangePair ShapeRangeTransfer::HWCNRange(const RangePair &ori_range, const TypeId &) {
  return {ori_range[kH], ori_range[kW], ori_range[kC], ori_range[kN]};
}
1859
// Range for NC1HWC0_C04: C0 is fixed to 4, C1 = ceil(C / 4).
RangePair ShapeRangeTransfer::NC1HWC04Range(const RangePair &ori_range, const TypeId &) {
  const std::pair<int64_t, int64_t> c0 = {k4, k4};
  const int64_t c1_min = (ori_range[kC].first + k4 - 1) / k4;
  const int64_t c1_max = CalMaxShape(ori_range[kC].second, (ori_range[kC].second + k4 - 1) / k4);
  return {ori_range[kN], {c1_min, c1_max}, ori_range[kH], ori_range[kW], c0};
}
1872
// Range for FRACTAL_Z_C04: device shape is (ceil(4*H*W/16), ceil(N/16), 16, 16).
RangePair ShapeRangeTransfer::FRAC_ZC04Range(const RangePair &ori_range, const TypeId &) {
  RangePair dst_range;
  const std::pair<int64_t, int64_t> c0 = {k4, k4};          // C0 fixed to 4
  const std::pair<int64_t, int64_t> c16 = {kNiSize, kNiSize};  // trailing 16x16 cube dims

  // First dim: C0*H*W padded up to a multiple of 16.
  auto tmp_max = CalMaxShape(c0.second * ori_range[kH].second * ori_range[kW].second,
                             (c0.second * ori_range[kH].second * ori_range[kW].second + kNiSize - 1) / kNiSize);
  const std::pair<int64_t, int64_t> first_dim = {
    (c0.first * ori_range[kH].first * ori_range[kW].first + kNiSize - 1) / kNiSize, tmp_max};

  // Second dim: N padded up to a multiple of 16.
  tmp_max = CalMaxShape(ori_range[kN].second, (ori_range[kN].second + kNiSize - 1) / kNiSize);
  const std::pair<int64_t, int64_t> no = {(ori_range[kN].first + kNiSize - 1) / kNiSize, tmp_max};
  dst_range.push_back(first_dim);
  dst_range.push_back(no);
  dst_range.push_back(c16);
  dst_range.push_back(c16);
  return dst_range;
}
1891
// Range for FRACTAL_Z: device shape is (C1*H*W, N1, Ni, C0), where C0 is the
// dtype-dependent cube size and N/C are padded up to the cube boundaries.
RangePair ShapeRangeTransfer::FRAC_ZRange(const RangePair &ori_range, const TypeId &type) {
  RangePair dst_range;
  auto cube = GetCubeSizeByType(type);
  const std::pair<int64_t, int64_t> c0 = {cube, cube};

  // N padded up to a multiple of Ni (16).
  auto tmp_max = CalMaxShape(ori_range[kN].second, ((ori_range[kN].second + kNiSize - 1) / kNiSize) * kNiSize);

  const std::pair<int64_t, int64_t> cout16 = {((ori_range[kN].first + kNiSize - 1) / kNiSize) * kNiSize, tmp_max};

  // C padded up to a multiple of the cube size.
  tmp_max = CalMaxShape(ori_range[kC].second, ((ori_range[kC].second + cube - 1) / cube) * cube);
  const std::pair<int64_t, int64_t> cin16 = {((ori_range[kC].first + cube - 1) / cube) * cube, tmp_max};

  // r0 = H * W * padded_C / cube  (the C1*H*W dim).
  tmp_max = CalMaxShape(ori_range[kH].second * ori_range[kW].second * cin16.second,
                        ori_range[kH].second * ori_range[kW].second * cin16.second / cube);
  const std::pair<int64_t, int64_t> r0 = {ori_range[kH].first * ori_range[kW].first * cin16.first / cube, tmp_max};

  // r1 = padded_N / Ni  (the N1 dim).
  tmp_max = CalMaxShape(cin16.second, cout16.second / kNiSize);
  const std::pair<int64_t, int64_t> r1 = {cout16.first / kNiSize, tmp_max};
  const std::pair<int64_t, int64_t> co = {kNiSize, kNiSize};
  dst_range.push_back(r0);
  dst_range.push_back(r1);
  dst_range.push_back(co);
  dst_range.push_back(c0);
  return dst_range;
}
1917
// Range for FRACTAL_NZ: leading dims are kept; the last two dims (H, W) become
// (W1, H1, H0, W0) = (ceil(W/cube), ceil(H/16), 16, cube).
RangePair ShapeRangeTransfer::FRAC_NZRange(const RangePair &ori_range, const TypeId &type) {
  RangePair dst_range;
  auto cube = GetCubeSizeByType(type);
  auto ori_size = ori_range.size();
  if (ori_size < kDims2) {
    // Fewer than two dims: nothing to tile, return unchanged.
    return ori_range;
  } else {
    // Copy every dim except the trailing (H, W) pair.
    (void)std::copy(ori_range.begin(), ori_range.end() - kDims2, std::back_inserter(dst_range));
  }
  const std::pair<int64_t, int64_t> c0 = {cube, cube};
  // W1 = ceil(W / cube).
  auto tmp_max = CalMaxShape(ori_range[ori_size - 1].second, (ori_range[ori_size - 1].second - 1) / cube + 1);
  const std::pair<int64_t, int64_t> w1 = {(ori_range[ori_size - 1].first - 1) / cube + 1, tmp_max};
  // H1 = ceil(H / 16).
  tmp_max = CalMaxShape(ori_range[ori_size - kDims2].second, (ori_range[ori_size - kDims2].second - 1) / kNiSize + 1);
  const std::pair<int64_t, int64_t> h1 = {(ori_range[ori_size - kDims2].first - 1) / kNiSize + 1, tmp_max};
  const std::pair<int64_t, int64_t> co = {kNiSize, kNiSize};
  dst_range.push_back(w1);
  dst_range.push_back(h1);
  dst_range.push_back(co);
  dst_range.push_back(c0);
  return dst_range;
}
1939
// Range for NC1HWC0: C0 is the dtype cube size, C1 = ceil(C / C0).
RangePair ShapeRangeTransfer::NC1HWC0Range(const RangePair &ori_range, const TypeId &type) {
  const auto cube = GetCubeSizeByType(type);
  const std::pair<int64_t, int64_t> c0 = {cube, cube};
  const int64_t c1_min = (ori_range[kC].first + cube - 1) / cube;
  const int64_t c1_max = CalMaxShape(ori_range[kC].second, (ori_range[kC].second + cube - 1) / cube);
  return {ori_range[kN], {c1_min, c1_max}, ori_range[kH], ori_range[kW], c0};
}
1953
// Range for FRACTAL_ZN_LSTM. The LSTM weight holds 4 gates, so the hidden
// size h = N / 4 and the input size i = C - h; the device shape ends with the
// two 16x16 cube dims.
RangePair ShapeRangeTransfer::FRAC_ZN_LSTMRange(const RangePair &ori_range, const TypeId &) {
  RangePair dst_range;
  const std::pair<int64_t, int64_t> c0 = {k4, k4};  // number of LSTM gates
  // Trailing cube dims are 16x16 (Ni x C0), matching the other fractal
  // formats (see FRAC_ZC04Range); {k4, k4} here was inconsistent with the
  // name c16 and with the 16x16 device tiling.
  const std::pair<int64_t, int64_t> c16 = {kNiSize, kNiSize};

  // h = N / 4 (hidden size per gate).
  auto tmp_max = CalMaxShape(ori_range[kN].second, ori_range[kN].second / c0.second);
  const std::pair<int64_t, int64_t> h = {ori_range[kN].first / c0.first, tmp_max};

  // i = C - h (input size).
  tmp_max = CalMaxShape(ori_range[kC].second * h.second, ori_range[kC].second - h.second);
  const std::pair<int64_t, int64_t> i = {ori_range[kC].first - h.first, tmp_max};

  // First dim: ceil(i / 16) + ceil(h / 16) row blocks.
  tmp_max = CalMaxShape(i.second * h.second, (i.second + kCube16 - 1) / kCube16 + (h.second + kCube16 - 1) / kCube16);
  const std::pair<int64_t, int64_t> first_dim = {(i.first + kCube16 - 1) / kCube16 + (h.first + kCube16 - 1) / kCube16,
                                                 tmp_max};

  // Second dim: 4 * ceil(h / 16) column blocks (one group per gate).
  tmp_max = CalMaxShape(h.second, c0.second * ((h.second + kCube16 - 1) / kCube16));
  const std::pair<int64_t, int64_t> second = {c0.first * ((h.first + kCube16 - 1) / kCube16), tmp_max};
  dst_range.push_back(first_dim);
  dst_range.push_back(second);
  dst_range.push_back(c16);
  dst_range.push_back(c16);
  return dst_range;
}
1977
// Range for NDC1HWC0: C0 is the dtype cube size, C1 = ceil(C / C0).
RangePair ShapeRangeTransfer::NDC1HWC0Range(const RangePair &ori_range, const TypeId &type) {
  const auto cube = GetCubeSizeByType(type);
  const std::pair<int64_t, int64_t> c0 = {cube, cube};
  const int64_t c1_min = (ori_range[C_ncdhw].first + cube - 1) / cube;
  const int64_t c1_max = CalMaxShape(ori_range[C_ncdhw].second, (ori_range[C_ncdhw].second + cube - 1) / cube);
  return {ori_range[N_ncdhw], ori_range[D_ncdhw], {c1_min, c1_max},
          ori_range[H_ncdhw], ori_range[W_ncdhw], c0};
}
1992
// Range for C1HWNCoC0: C1 = ceil(C / cube), Co = C0 = cube.
RangePair ShapeRangeTransfer::C1HWNCOC0Range(const RangePair &ori_range, const TypeId &type) {
  const auto cube = GetCubeSizeByType(type);
  const std::pair<int64_t, int64_t> c0 = {cube, cube};
  const int64_t c1_min = (ori_range[kC].first - 1) / cube + 1;
  const int64_t c1_max = CalMaxShape(ori_range[kC].second, (ori_range[kC].second - 1) / cube + 1);
  return {{c1_min, c1_max}, ori_range[kH], ori_range[kW], ori_range[kN], c0, c0};
}
2007
// Range for FRACTAL_Z_3D: device shape is (D*C1*H*W, N1, Ni, C0).
RangePair ShapeRangeTransfer::FRAC_Z_3DRange(const RangePair &ori_range, const TypeId &type) {
  RangePair dst_range;
  auto cube = GetCubeSizeByType(type);
  const std::pair<int64_t, int64_t> c0 = {cube, cube};
  // C1 = ceil(C / cube).
  auto tmp_max = CalMaxShape(ori_range[C_ncdhw].second, (ori_range[C_ncdhw].second + cube - 1) / cube);
  const std::pair<int64_t, int64_t> c1 = {(ori_range[C_ncdhw].first + cube - 1) / cube, tmp_max};

  // N1 = ceil(N / 16).
  tmp_max = CalMaxShape(ori_range[N_ncdhw].second, (ori_range[N_ncdhw].second + kNiSize - 1) / kNiSize);
  const std::pair<int64_t, int64_t> n1 = {(ori_range[N_ncdhw].first + kNiSize - 1) / kNiSize, tmp_max};

  // First dim: D * C1 * H * W.
  const int64_t r1_0 = ori_range[D_ncdhw].first * c1.first * ori_range[H_ncdhw].first * ori_range[W_ncdhw].first;
  // NOTE(review): both CalMaxShape args are the same product here, so this is
  // only a dynamic-dim passthrough — confirm this is intended rather than a
  // copy of the min-side expression.
  const int64_t r1_1 =
    CalMaxShape(ori_range[D_ncdhw].second * c1.second * ori_range[H_ncdhw].second * ori_range[W_ncdhw].second,
                ori_range[D_ncdhw].second * c1.second * ori_range[H_ncdhw].second * ori_range[W_ncdhw].second);
  const std::pair<int64_t, int64_t> r1 = {r1_0, r1_1};
  dst_range.push_back(r1);
  dst_range.push_back(n1);
  dst_range.push_back(c1);
  dst_range.push_back(c0);
  return dst_range;
}
InitInfo()2029 void FormatHelper::InitInfo() {
2030 info = {{kOpFormat_DEFAULT, FormatInfo(kOpFormat_DEFAULT, false)},
2031 {kOpFormat_NC1HWC0, FormatInfo(kOpFormat_NCHW, true)},
2032 {kOpFormat_ND, FormatInfo(kOpFormat_ND, false)},
2033 {kOpFormat_NCHW, FormatInfo(kOpFormat_NCHW, false)},
2034 {kOpFormat_NHWC, FormatInfo(kOpFormat_NHWC, false)},
2035 {kOpFormat_FRAC_NZ, FormatInfo(kOpFormat_ND, true)},
2036 {kOpFormat_FRAC_Z, FormatInfo(kOpFormat_NCHW, true)},
2037 {kOpFormat_NDHWC, FormatInfo(kOpFormat_NCDHW, false)},
2038 {kOpFormat_NCDHW, FormatInfo(kOpFormat_NCDHW, false)},
2039 {kOpFormat_NDC1HWC0, FormatInfo(kOpFormat_NCDHW, true)},
2040 {kOpFormat_FRACTAL_Z_3D, FormatInfo(kOpFormat_NCDHW, true)}};
2041 }
2042
GetInstance()2043 FormatHelper &FormatHelper::GetInstance() noexcept {
2044 static FormatHelper instance{};
2045 return instance;
2046 }
2047
GetBaseFormat(const std::string & format)2048 const std::string FormatHelper::GetBaseFormat(const std::string &format) {
2049 const auto &iter = info.find(format);
2050 if (iter != info.end()) {
2051 return iter->second.baseFormat;
2052 } else {
2053 return "";
2054 }
2055 }
2056
IsBaseFormatType(const std::string & format)2057 bool FormatHelper::IsBaseFormatType(const std::string &format) {
2058 const auto &iter = info.find(format);
2059 if (iter == info.end()) {
2060 return false;
2061 }
2062
2063 return iter->first == iter->second.baseFormat;
2064 }
2065
IsPadded(const std::string & format)2066 bool FormatHelper::IsPadded(const std::string &format) {
2067 auto itr = info.find(format);
2068 if (itr != info.end()) {
2069 return itr->second.isPadded;
2070 }
2071 MS_LOG(INFO) << "unknown format type:" << format;
2072 return true;
2073 }
2074
Clear()2075 void FormatHelper::Clear() { info.clear(); }
2076 } // namespace trans
2077 } // namespace mindspore
2078