/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "runtime/device/ms_device_shape_transfer.h"
#include <functional>
#include <unordered_map>
#include <numeric>
#include <utility>
#include <algorithm>

namespace mindspore {
namespace trans {
static const ShapeValueDType kShapeDimAny = abstract::Shape::kShapeDimAny;

const int b1 = 1;
const int b2 = 2;
const int b4 = 4;
const int b8 = 8;
const int64_t kCubeSize = 16;
const int64_t kCube16 = kCubeSize;
const int64_t kCube32 = 32;
const int64_t kCube64 = 64;
const int64_t kCubeSize_C04 = 4;
const int64_t kNiSize = 16;
constexpr int kDims2 = 2;
constexpr int64_t k4 = 4;
static const std::set<TypeId> C0_64 = {kNumberTypeInt4};
static const std::set<TypeId> C0_32 = {kNumberTypeUInt8, kNumberTypeInt8};
namespace {
const size_t hw_h = 1;
const size_t hw_w = 2;
const size_t fnz_w1 = 4;
const size_t fnz_h1 = 3;
const size_t fnz_h0 = 2;
const size_t fnz_w0 = 1;
const size_t fz_n0 = 1;
const size_t fz_ni = 2;
const size_t fz_c0 = 3;
bool HasShapeDynamic(const ShapeVector &shape_list) {
  return std::any_of(shape_list.begin(), shape_list.end(), [](int64_t v) { return v == kShapeDimAny; });
}

inline int64_t CalMaxShape(int64_t ori_val, int64_t new_val) {
  if (ori_val < 0) {
    return kShapeDimAny;
  }

  return new_val;
}

template <typename T>
T Gcd(T a, T b) {
  if (b == 0) {
    return 0;
  }
  T c = b;
  while (a % b != 0) {
    c = a % b;
    a = b;
    b = c;
  }
  return c;
}

template <typename T>
T Lcm(T a, T b) {
  if (b == 0) {
    return 0;
  }
  T ret = (a * b) / (Gcd(a, b));
  return ret;
}

template <typename T>
T DivCeil(T n1, T n2) {
  if (n2 != 0) {
    return (n1 + n2 - 1) / n2;
  }
  return 0;
}
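// Illustrative values for the helpers above: Gcd(12, 18) == 6, Lcm(12, 18) == 36 and
// DivCeil(10, 3) == 4; by convention a zero second argument yields 0.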

template <typename T>
bool CheckDims(const std::vector<T> &shape) {
  if (shape.size() != kDim4) {
    MS_LOG(ERROR) << "Host shape dims should be 4";
    return false;
  }
  return true;
}

int64_t GetCubeSizeByType(const TypeId &data_type) {
  if (C0_32.find(data_type) != C0_32.end()) {
    return kCube32;
  }
  if (C0_64.find(data_type) != C0_64.end()) {
    return kCube64;
  }
  return kCube16;
}
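// The cube (C0) dimension depends on the element type: 32 for uint8/int8, 64 for int4,
// and 16 otherwise (e.g. float16/float32).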

RangePair PaddingRangeTo5dDefault(const RangePair &ori_range) {
  RangePair dst_range(kNcdhw, std::pair<int64_t, int64_t>(1, 1));
  switch (ori_range.size()) {
    case N_ncdhw:
      return ori_range;
    case C_ncdhw:
      dst_range[C_ncdhw] = ori_range[N_ncdhw];
      break;
    case D_ncdhw:
      dst_range[C_ncdhw] = ori_range[N_ncdhw];
      dst_range[D_ncdhw] = ori_range[C_ncdhw];
      break;
    case H_ncdhw:
      dst_range[C_ncdhw] = ori_range[N_ncdhw];
      dst_range[D_ncdhw] = ori_range[C_ncdhw];
      dst_range[H_ncdhw] = ori_range[D_ncdhw];
      break;
    case W_ncdhw:
      dst_range[C_ncdhw] = ori_range[N_ncdhw];
      dst_range[D_ncdhw] = ori_range[C_ncdhw];
      dst_range[H_ncdhw] = ori_range[D_ncdhw];
      dst_range[W_ncdhw] = ori_range[H_ncdhw];
      break;
    default:
      MS_LOG(INTERNAL_EXCEPTION) << "Unexpected shape size: " << ori_range.size();
  }
  return dst_range;
}

RangePair PaddingRangeTo5D(const RangePair &ori_range, const std::string &padding_str = {""}) {
  std::vector<Axis5D> padding_axis;
  StringToAxisVector5D(padding_str, &padding_axis);
  if (padding_axis.empty() || ori_range.size() > padding_axis.size()) {
    return PaddingRangeTo5dDefault(ori_range);
  }

  RangePair dst_range(kNcdhw, std::pair<int64_t, int64_t>(1, 1));
  for (size_t index = 0; index < ori_range.size(); index++) {
    dst_range[padding_axis[index]] = ori_range[index];
  }
  return dst_range;
}

RangePair PaddingRangeTo4dDefault(const RangePair &ori_range) {
  RangePair dst_range(kNchwDims, std::pair<int64_t, int64_t>(1, 1));
  switch (ori_range.size()) {
    case kN:
      return dst_range;
    case kC:
      dst_range[kC] = ori_range[kN];
      break;
    case kH:
      dst_range[kC] = ori_range[kN];
      dst_range[kH] = ori_range[kC];
      break;
    case kW:
      dst_range[kC] = ori_range[kN];
      dst_range[kH] = ori_range[kC];
      dst_range[kW] = ori_range[kH];
      break;
    case kNchwDims:
      return ori_range;
    default:
      MS_LOG(INTERNAL_EXCEPTION) << "Unexpected range size: " << ori_range.size();
  }
  return dst_range;
}
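// Default padding keeps the original axes right-aligned, e.g. a rank-3 range {a, b, c}
// is padded to NCHW as {(1, 1), a, b, c}, with N getting the default range (1, 1).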

RangePair PaddingRangeTo4D(const RangePair &ori_range, const std::string &padding_str = {""}) {
  std::vector<Axis> padding_axis;
  StringToAxisVector4D(padding_str, &padding_axis);
  if (padding_axis.empty() || ori_range.size() > padding_axis.size()) {
    return PaddingRangeTo4dDefault(ori_range);
  }

  RangePair dst_range(kNchwDims, std::pair<int64_t, int64_t>(1, 1));
  for (size_t index = 0; index < ori_range.size(); index++) {
    dst_range[padding_axis[index]] = ori_range[index];
  }
  return dst_range;
}
}  // namespace

void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
  MS_EXCEPTION_IF_NULL(reshape_type_vec);
  if (reshape_type_str.empty()) {
    MS_LOG(DEBUG) << "Reshape type str is empty, no need padding.";
    return;
  }
  for (const auto &c : reshape_type_str) {
    switch (c) {
      case 'N':
        reshape_type_vec->push_back(N);
        break;
      case 'C':
        reshape_type_vec->push_back(C);
        break;
      case 'H':
        reshape_type_vec->push_back(H);
        break;
      case 'W':
        reshape_type_vec->push_back(W);
        break;
      default:
        MS_LOG(INTERNAL_EXCEPTION) << "Unknown axis " << c << " in reshape type.";
    }
  }
}

void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec) {
  MS_EXCEPTION_IF_NULL(reshape_type_vec);
  if (reshape_type_str.empty()) {
    MS_LOG(DEBUG) << "Reshape type str is empty, no need padding.";
    return;
  }
  for (const auto &c : reshape_type_str) {
    switch (c) {
      case 'N':
        reshape_type_vec->push_back(N_ncdhw);
        break;
      case 'C':
        reshape_type_vec->push_back(C_ncdhw);
        break;
      case 'D':
        reshape_type_vec->push_back(D_ncdhw);
        break;
      case 'H':
        reshape_type_vec->push_back(H_ncdhw);
        break;
      case 'W':
        reshape_type_vec->push_back(W_ncdhw);
        break;
      default:
        MS_LOG(INTERNAL_EXCEPTION) << "Unknown axis " << c << " in reshape type.";
    }
  }
}

bool IsNeedPadding(const std::string &format, const ShapeVector &shape) {
  if (shape.empty()) {
    return false;
  }
  if (IsDynamicRank(shape) && !IsOneOfDynRankNeedPadShape(format)) {
    return false;
  }
  if (format == kOpFormat_DEFAULT || format == kOpFormat_NCHW || IsOneOfNoPaddingFormat(format)) {
    return false;
  } else if (shape.size() < kDim4) {
    return true;
  }
  return false;
}

ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  ShapeVector host_shape;
  if (node->isa<ValueNode>()) {
    auto value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    auto node_value = value_node->value();
    MS_EXCEPTION_IF_NULL(node_value);
    // Scalar has no shape.
    if (node_value->isa<Scalar>()) {
      return {};
    }
    if (node_value->isa<StringImm>()) {
      auto string_value = node_value->cast<StringImmPtr>();
      MS_EXCEPTION_IF_NULL(string_value);
      return {SizeToLong(string_value->ToString().size())};
    }
    if (node_value->isa<ValueSequence>()) {
      MS_LOG(INFO) << "GetRuntimePaddingShape does not support the value sequence for value node:"
                   << node->fullname_with_scope() << ", debug name:" << node->DebugString();
      return {0};
    }
    auto tensor = node_value->cast<tensor::TensorPtr>();
    if (tensor == nullptr) {
      MS_LOG(INTERNAL_EXCEPTION) << "The value of node [" << node->DebugString() << "] cannot be converted to a tensor.";
    }
    host_shape = tensor->shape();
  } else {
    host_shape = common::AnfAlgo::GetOutputInferShape(node, index);
  }
  auto format = AnfAlgo::GetOutputFormat(node, index);
  if (IsNeedPadding(format, host_shape)) {
    host_shape = PaddingShape(host_shape, format, AnfAlgo::GetOutputReshapeType(node, index), node);
  }
  return host_shape;
}

bool TransDataType(const TypeIdArgs &args, void *result) {
  DataTypeTransfer dataTypeTransfer;
  return dataTypeTransfer.TransDataType(args, result);
}

bool TransFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, size_t index) {
  FormatTransfer formatTransfer;
  return formatTransfer.TransDataByFormat(args, result, node, index, true);
}

bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, int64_t groups) {
  FormatTransfer formatTransfer;
  return formatTransfer.TransDataBackwordCore(args, result, groups);
}

bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, const AnfNodePtr &node, size_t index) {
  FormatTransfer formatTransfer;
  return formatTransfer.TransDataByFormat(args, result, node, index, false);
}

/**###################### DATA TYPE TRANS ################################*/
void CheckMemSize(const TypeIdArgs &args) {
  auto src_type_size = abstract::TypeIdSize(args.src_data_type);
  auto dst_type_size = abstract::TypeIdSize(args.dst_data_type);
  if (src_type_size < 1 || dst_type_size < 1) {
    MS_LOG(INTERNAL_EXCEPTION) << "Invalid src or dst data type. Src type: " << TypeIdLabel(args.src_data_type)
                               << ", dst type: " << TypeIdLabel(args.dst_data_type);
  }
  if (SizeToLong(args.data_size / src_type_size) != args.src_shape_size) {
    MS_LOG(INTERNAL_EXCEPTION) << "Invalid src data shape size. Src shape size: " << args.src_shape_size
                               << ", data size / src type size: " << args.data_size / src_type_size;
  }
}

template <typename SrcT, typename DstT>
void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const int64_t data_size) {
  CheckMemSize(args);
  for (int64_t idx = 0; idx != data_size; idx++) {
    SrcT src_data = static_cast<const SrcT *>(args.data)[idx];
    static_cast<DstT *>(dst)[idx] = static_cast<DstT>(src_data);
  }
}
template <typename SrcT>
void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const int64_t data_size) {
  CheckMemSize(args);
  auto src_data = static_cast<const SrcT *>(args.data);
  auto half_data = static_cast<float16 *>(dst);
  for (int64_t i = 0; i < data_size; i++) {
    half_data[i] = float16(src_data[i]);
  }
}

bool DataTypeTransfer::CastKernel(const TypeIdArgs &args, void *dst, int64_t data_size, DataTypeTransMode mode) const {
  using DtypeKernel = std::function<void(const TypeIdArgs &, void *, const int64_t)>;
  const std::map<DataTypeTransMode, DtypeKernel> cast_kernel_map{
    {DataTypeTransMode::FROM_BOOL_TO_UINT8, TransDataSrc2Dst<int8_t, uint8_t>},
    {DataTypeTransMode::FROM_BOOL_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
    {DataTypeTransMode::FROM_BOOL_TO_FLOAT16, TransDataSrc2Fp16<int8_t>},
    {DataTypeTransMode::FROM_BOOL_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
    {DataTypeTransMode::FROM_INT8_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
    {DataTypeTransMode::FROM_INT8_TO_FLOAT16, TransDataSrc2Fp16<int8_t>},
    {DataTypeTransMode::FROM_INT8_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
    {DataTypeTransMode::FROM_UINT8_TO_INT32, TransDataSrc2Dst<uint8_t, int32_t>},
    {DataTypeTransMode::FROM_UINT8_TO_FLOAT16, TransDataSrc2Fp16<uint8_t>},
    {DataTypeTransMode::FROM_UINT8_TO_FLOAT, TransDataSrc2Dst<uint8_t, float>},
    {DataTypeTransMode::FROM_UINT16_TO_INT32, TransDataSrc2Dst<uint16_t, int32_t>},
    {DataTypeTransMode::FROM_INT16_TO_INT32, TransDataSrc2Dst<int16_t, int32_t>},
    {DataTypeTransMode::FROM_INT16_TO_INT64, TransDataSrc2Dst<int16_t, int64_t>},
    {DataTypeTransMode::FROM_INT32_TO_BOOL, TransDataSrc2Dst<int32_t, int8_t>},
    {DataTypeTransMode::FROM_INT32_TO_INT8, TransDataSrc2Dst<int32_t, int8_t>},
    {DataTypeTransMode::FROM_INT32_TO_UINT8, TransDataSrc2Dst<int32_t, uint8_t>},
    {DataTypeTransMode::FROM_INT32_TO_INT16, TransDataSrc2Dst<int32_t, int16_t>},
    {DataTypeTransMode::FROM_INT32_TO_UINT16, TransDataSrc2Dst<int32_t, uint16_t>},
    {DataTypeTransMode::FROM_INT32_TO_FLOAT16, TransDataSrc2Fp16<int32_t>},
    {DataTypeTransMode::FROM_INT32_TO_FLOAT, TransDataSrc2Dst<int32_t, float>},
    {DataTypeTransMode::FROM_INT32_TO_INT64, TransDataSrc2Dst<int32_t, int64_t>},
    {DataTypeTransMode::FROM_INT64_TO_INT16, TransDataSrc2Dst<int64_t, int16_t>},
    {DataTypeTransMode::FROM_INT64_TO_INT32, TransDataSrc2Dst<int64_t, int32_t>},
    {DataTypeTransMode::FROM_FLOAT16_TO_UINT8, TransDataSrc2Dst<float16, uint8_t>},
    {DataTypeTransMode::FROM_FLOAT16_TO_INT32, TransDataSrc2Dst<float16, int32_t>},
    {DataTypeTransMode::FROM_FLOAT_TO_INT32, TransDataSrc2Dst<float, int32_t>},
    {DataTypeTransMode::FROM_FLOAT32_TO_FLOAT64, TransDataSrc2Dst<float, double>},
    {DataTypeTransMode::FROM_FLOAT64_TO_FLOAT32, TransDataSrc2Dst<double, float>}};

  if (mode == DataTypeTransMode::FROM_FLOAT_TO_FLOAT16) {
    device::FloatToHalf(dst, args.data, LongToSize(data_size));
    return true;
  } else if (mode == DataTypeTransMode::FROM_FLOAT16_TO_FLOAT) {
    device::HalfToFloat(dst, args.data, LongToSize(data_size));
    return true;
  }
  auto iter = cast_kernel_map.find(mode);
  if (iter != cast_kernel_map.end()) {
    iter->second(args, dst, data_size);
    return true;
  } else {
    MS_LOG(ERROR) << "Can not find a datatype trans function. Src type: " << TypeIdLabel(args.src_data_type)
                  << ", dst_type: " << TypeIdLabel(args.dst_data_type);
    return false;
  }
}

bool DataTypeTransfer::TransDataType(const TypeIdArgs &args, void *result) const {
  MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.src_data_type) << " to "
                << TypeIdLabel(args.dst_data_type);
  MS_EXCEPTION_IF_NULL(result);
  std::pair<TypeId, TypeId> type_info(args.src_data_type, args.dst_data_type);
  auto iter = mode_map.find(type_info);
  if (iter == mode_map.end()) {
    MS_LOG(ERROR) << "Can not find a datatype trans type. src_type: " << TypeIdLabel(args.src_data_type)
                  << ", dst_type: " << TypeIdLabel(args.dst_data_type);
    return false;
  }
  auto trans_mode = iter->second;
  if (!CastKernel(args, result, args.src_shape_size, trans_mode)) {
    MS_LOG(ERROR) << "Failed to trans datatype. Src: " << TypeIdLabel(args.src_data_type)
                  << ", dst: " << TypeIdLabel(args.dst_data_type);
    return false;
  }
  return true;
}
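// Minimal usage sketch (illustrative; how callers build TypeIdArgs may differ): fill a
// TypeIdArgs with the source buffer (args.data), its byte size (args.data_size), the
// element count (args.src_shape_size) and the src/dst type ids, then call the free
// function TransDataType(args, result) above to cast element-wise into a pre-allocated
// destination buffer.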

/**###################### DATA SHAPE TRANS ################################*/
ShapeVector DeviceShapeTransfer::GetDeviceShapeByFormat(const ShapeVector &shape, const std::string &format,
                                                        const AnfNodePtr &node, size_t index, const TypeId &type,
                                                        bool is_output) const {
  auto dev_shape = GetFixedDeviceShape(shape, node, index, is_output);
  if (dev_shape.has_value()) {
    return dev_shape.value();
  }
  int64_t groups = 1;
  if (format == kOpFormat_FRAC_Z) {
    groups = common::AnfAlgo::GetAttrGroups(node, index);
  }
  ShapeVector input_hidden_size = {kAlign16, kAlign16};
  if (format == kOpFormat_FRACTAL_ZN_RNN || format == kOpFormat_ND_RNN_BIAS) {
    input_hidden_size = GetAttrInputAndHiddenSize(node);
  }
  if (node != nullptr) {
    MS_LOG(DEBUG) << "Start trans infer shape to device shape for node: " << node->DebugString()
                  << ", format: " << format;
  }
  return TransCore(shape, format, type, groups, input_hidden_size);
}

ShapeVector DeviceShapeTransfer::GetDeviceShapeByFormat(const ShapeVector &shape, const std::string &format,
                                                        const TypeId &type, int64_t groups,
                                                        const ShapeVector &input_hidden_size) const {
  return TransCore(shape, format, type, groups, input_hidden_size);
}

std::optional<ShapeVector> DeviceShapeTransfer::GetFixedDeviceShape(const ShapeVector &, const AnfNodePtr &node,
                                                                    size_t index, bool is_output) const {
  if (node == nullptr || !node->isa<CNode>()) {
    return {};
  }
  auto attr_name = is_output ? kAttrFixedOutputDeviceShape : kAttrFixedInputDeviceShape;
  auto cnode = node->cast<CNodePtr>();
  if (!common::AnfAlgo::HasNodeAttr(attr_name, cnode)) {
    return {};
  }

  auto shapes = common::AnfAlgo::GetNodeAttr<std::vector<ShapeVector>>(cnode, attr_name);
  if (index >= shapes.size()) {
    MS_LOG(INFO) << "Index is out of range, got index: " << index << ", shape size: " << shapes.size();
    return {};
  }
  return std::optional<ShapeVector>(std::move(shapes[index]));
}

ShapeVector DeviceShapeTransfer::TransCore(const ShapeVector &shape, const std::string &format, const TypeId &type,
                                           int64_t groups, const ShapeVector &input_hidden_size) const {
  using DeviceShapeTransferFunc = std::function<ShapeVector(const ShapeVector &, const TypeId &)>;
  static const mindspore::HashMap<std::string, DeviceShapeTransferFunc> device_shape_map = {
    {kOpFormat_NCHW, NCHWDeviceShape},
    {kOpFormat_NHWC, NHWCDeviceShape},
    {kOpFormat_HWCN, HWCNDeviceShape},
    {kOpFormat_NCDHW, NCDHWDeviceShape},
    {kOpFormat_FRAC_Z, FRAC_ZDeviceShape},
    {kOpFormat_FRAC_NZ, FRAC_NZDeviceShape},
    {kOpFormat_NC1HWC0, NC1HWC0DeviceShape},
    {kOpFormat_NDC1HWC0, NDC1HWC0DeviceShape},
    {kOpFormat_C1HWNCoC0, C1HWNCOC0DeviceShape},
    {kOpFormat_NC1HWC0_C04, NC1HWC04DeviceShape},
    {kOpFormat_FRACTAL_Z_3D, FRAC_Z3DDeviceShape},
    {kOpFormat_FRACTAL_Z_C04, FRAC_ZC04DeviceShape},
    {kOpFormat_ChannelLast, ChannelLastDeviceShape},
    {kOpFormat_FRACTAL_ZN_LSTM, FRAC_ZN_LSTMDeviceShape}};
  if (format == kOpFormat_ND || format == kOpFormat_DEFAULT || format == kOpFormat_NCHW) {
    return shape;
  }
  if (groups > 1 && format == kOpFormat_FRAC_Z) {
    return FRAC_ZDeviceShapeWithGroups(shape, type, groups);
  }
  if (format == kOpFormat_FRACTAL_ZN_RNN) {
    return FRAC_ZN_RNNDeviceShape(shape, type, input_hidden_size);
  }
  if (format == kOpFormat_ND_RNN_BIAS) {
    return NDRNNBiasDeviceShape(shape, type, input_hidden_size[1]);
  }
  auto temp_shape = shape;
  if (!IsOneOfNoPaddingFormat(format) && format != kOpFormat_FRACTAL_ZN_LSTM && shape.size() < kDim4 &&
      !IsOneOf3DFormat(format)) {
    MS_LOG(INFO) << "Origin shape size is less than 4, pad it to 4D with the default rule first.";
    temp_shape = PaddingShapeTo4dDefault(shape);
  }
  if (shape.size() != kDim5 && IsOneOf3DFormat(format)) {
    temp_shape = PaddingShapeTo5dDefault(shape);
  }
  auto iter = device_shape_map.find(format);
  if (iter == device_shape_map.end()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Unexpected format[" << format << "]";
  }
  return iter->second(temp_shape, type);
}

ShapeVector DeviceShapeTransfer::NCHWDeviceShape(const ShapeVector &shape, const TypeId &) {
  if (!CheckDims(shape)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed.";
  }
  return shape;
}

ShapeVector DeviceShapeTransfer::NHWCDeviceShape(const ShapeVector &shape, const TypeId &) {
  if (!CheckDims(shape)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed.";
  }
  ShapeVector device_shape;
  device_shape.push_back(shape[kN]);
  device_shape.push_back(shape[kH]);
  device_shape.push_back(shape[kW]);
  device_shape.push_back(shape[kC]);
  return device_shape;
}

ShapeVector DeviceShapeTransfer::HWCNDeviceShape(const ShapeVector &shape, const TypeId &) {
  if (!CheckDims(shape)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed.";
  }
  ShapeVector device_shape;
  device_shape.push_back(shape[kH]);
  device_shape.push_back(shape[kW]);
  device_shape.push_back(shape[kC]);
  device_shape.push_back(shape[kN]);
  return device_shape;
}

ShapeVector DeviceShapeTransfer::FRAC_ZDeviceShape(const ShapeVector &shape, const TypeId &type) {
  if (!CheckDims(shape)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed.";
  }
  ShapeVector device_shape;
  auto c0 = GetCubeSizeByType(type);
  if (HasShapeDynamic({shape[kC], shape[kH], shape[kW]})) {
    device_shape.push_back(abstract::Shape::kShapeDimAny);
  } else {
    auto c1 = (shape[kC] + c0 - 1) / c0;
    device_shape.push_back(shape[kH] * shape[kW] * c1);
  }
  if (shape[kN] == abstract::Shape::kShapeDimAny) {
    device_shape.push_back(abstract::Shape::kShapeDimAny);
  } else {
    auto no = (shape[kN] + kNiSize - 1) / kNiSize;
    device_shape.push_back(no);
  }
  device_shape.push_back(kNiSize);
  device_shape.push_back(c0);
  return device_shape;
}
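// Example for FRAC_Z above: NCHW [64, 32, 3, 3] with float16 (c0 = 16) gives c1 = 2,
// so the device shape is [3 * 3 * 2, ceil(64 / 16), 16, 16] = [18, 4, 16, 16].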

ShapeVector DeviceShapeTransfer::NC1HWC0DeviceShape(const ShapeVector &shape, const TypeId &type) {
  if (!CheckDims(shape)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed.";
  }
  ShapeVector device_shape;
  auto c0 = GetCubeSizeByType(type);
  auto c1 = (shape[kC] == abstract::Shape::kShapeDimAny) ? abstract::Shape::kShapeDimAny : (shape[kC] + c0 - 1) / c0;
  device_shape.push_back(shape[kN]);
  device_shape.push_back(c1);
  device_shape.push_back(shape[kH]);
  device_shape.push_back(shape[kW]);
  device_shape.push_back(c0);
  return device_shape;
}
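// Example for NC1HWC0 above: NCHW [32, 20, 7, 7] with float16 (c0 = 16) gives
// c1 = ceil(20 / 16) = 2, so the device shape is [32, 2, 7, 7, 16].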

ShapeVector DeviceShapeTransfer::NDC1HWC0DeviceShape(const ShapeVector &shape, const TypeId &type) {
  if (shape.size() == kDim6) {
    return shape;
  }
  if (shape.size() != kDim5) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim: " << shape.size();
  }
  ShapeVector device_shape;
  auto c0 = GetCubeSizeByType(type);
  auto c1 = (shape[1] == abstract::Shape::kShapeDimAny) ? abstract::Shape::kShapeDimAny : (shape[1] + c0 - 1) / c0;
  device_shape.push_back(shape[N_ncdhw]);
  device_shape.push_back(shape[D_ncdhw]);
  device_shape.push_back(c1);
  device_shape.push_back(shape[H_ncdhw]);
  device_shape.push_back(shape[W_ncdhw]);
  device_shape.push_back(c0);
  return device_shape;
}
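// Example for NDC1HWC0 above: NCDHW [8, 20, 4, 5, 5] with float16 (c0 = 16) gives
// c1 = 2, so the device shape is [8, 4, 2, 5, 5, 16].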

ShapeVector DeviceShapeTransfer::FRAC_Z3DDeviceShape(const ShapeVector &shape, const TypeId &type) {
  if (shape.size() != kDim5) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim: " << shape.size();
  }
  ShapeVector device_shape;
  auto c0 = GetCubeSizeByType(type);
  if (HasShapeDynamic({shape[C_ncdhw], shape[D_ncdhw], shape[H_ncdhw], shape[W_ncdhw]})) {
    device_shape.push_back(abstract::Shape::kShapeDimAny);
  } else {
    auto c1 = (shape[1] + c0 - 1) / c0;
    device_shape.push_back(shape[D_ncdhw] * c1 * shape[H_ncdhw] * shape[W_ncdhw]);
  }
  auto no =
    (shape[0] == abstract::Shape::kShapeDimAny) ? abstract::Shape::kShapeDimAny : (shape[0] + kNiSize - 1) / kNiSize;
  device_shape.push_back(no);
  device_shape.push_back(kNiSize);
  device_shape.push_back(c0);
  return device_shape;
}

ShapeVector DeviceShapeTransfer::C1HWNCOC0DeviceShape(const ShapeVector &shape, const TypeId &type) {
  if (!CheckDims(shape)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed.";
  }
  ShapeVector device_shape;
  auto c0 = GetCubeSizeByType(type);
  if (shape[kC] == abstract::Shape::kShapeDimAny) {
    device_shape.push_back(abstract::Shape::kShapeDimAny);
  } else {
    device_shape.push_back((shape[kC] - 1) / c0 + 1);
  }
  device_shape.push_back(shape[kH]);
  device_shape.push_back(shape[kW]);
  device_shape.push_back(shape[kN]);
  device_shape.push_back(c0);
  device_shape.push_back(c0);
  return device_shape;
}
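// Example for C1HWNCoC0 above: NCHW [64, 20, 3, 3] with float16 (c0 = 16) gives
// the device shape [2, 3, 3, 64, 16, 16].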

ShapeVector DeviceShapeTransfer::FRAC_ZC04DeviceShape(const ShapeVector &shape, const TypeId &) {
  if (!CheckDims(shape)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed.";
  }
  ShapeVector device_shape;
  const int64_t C04 = 4;
  int64_t first_dim;
  if (HasShapeDynamic({shape[kH], shape[kW]})) {
    first_dim = abstract::Shape::kShapeDimAny;
  } else {
    first_dim = DivCeil(C04 * shape[kH] * shape[kW], kCubeSize);
  }
  auto no =
    (shape[kN] == abstract::Shape::kShapeDimAny) ? abstract::Shape::kShapeDimAny : DivCeil(shape.at(kN), kCubeSize);
  device_shape.push_back(first_dim);
  device_shape.push_back(no);
  device_shape.push_back(kCubeSize);
  device_shape.push_back(kCubeSize);
  return device_shape;
}

ShapeVector DeviceShapeTransfer::NC1HWC04DeviceShape(const ShapeVector &shape, const TypeId &) {
  if (!CheckDims(shape)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed.";
  }
  ShapeVector device_shape;
  const int64_t C04 = 4;
  const int64_t C1 =
    (shape[kC] == abstract::Shape::kShapeDimAny) ? abstract::Shape::kShapeDimAny : DivCeil(shape.at(kC), C04);
  device_shape.push_back(shape[kN]);
  device_shape.push_back(C1);
  device_shape.push_back(shape[kH]);
  device_shape.push_back(shape[kW]);
  device_shape.push_back(C04);
  return device_shape;
}

ShapeVector DeviceShapeTransfer::NCDHWDeviceShape(const ShapeVector &shape, const TypeId &) {
  if (shape.size() < kDim5) {
    MS_LOG(INTERNAL_EXCEPTION) << "Shape dims must be 5 when format is NCDHW.";
  }
  return shape;
}

ShapeVector DeviceShapeTransfer::ChannelLastDeviceShape(const ShapeVector &shape, const TypeId &) {
  auto dim = shape.size();
  ShapeVector axis;
  axis.resize(dim);
  const int step_value = 2;
  std::iota(axis.begin() + 1, axis.end(), step_value);
  axis[dim - 1] = 1;
  ShapeVector device_shape;
  (void)std::transform(axis.begin(), axis.end(), std::back_inserter(device_shape),
                       [&shape](size_t n) { return shape[n]; });
  return device_shape;
}
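// ChannelLast moves the channel axis to the end, e.g. NCHW [8, 3, 224, 224] becomes
// [8, 224, 224, 3] (permutation [0, 2, 3, 1]).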

ShapeVector DeviceShapeTransfer::FRAC_NZDeviceShape(const ShapeVector &shape, const TypeId &type) {
  ShapeVector device_shape;
  auto c0 = GetCubeSizeByType(type);
  if (shape.size() == 1 && (shape[0] == 1 || shape[0] % c0 == 0)) {
    // For shapes like [1] and [1024] we can treat them as NZ shapes directly.
    return shape;
  }
  if (shape.size() == 1) {
    device_shape.push_back(DivCeil(shape[0], c0));
    device_shape.push_back(1);
    device_shape.push_back(kCubeSize);
    device_shape.push_back(c0);
    return device_shape;
  } else {
    const auto remove_dim = 2;
    (void)std::copy(shape.begin(), shape.end() - remove_dim, std::back_inserter(device_shape));
  }
  int64_t h_shape = shape[shape.size() - kH];
  int64_t w_shape = shape[shape.size() - 1];
  int64_t w1 = (w_shape == abstract::Shape::kShapeDimAny) ? abstract::Shape::kShapeDimAny : (w_shape - 1) / c0 + 1;
  int64_t h1 =
    (h_shape == abstract::Shape::kShapeDimAny) ? abstract::Shape::kShapeDimAny : (h_shape - 1) / kCubeSize + 1;
  device_shape.push_back(w1);
  device_shape.push_back(h1);
  device_shape.push_back(kCubeSize);
  device_shape.push_back(c0);
  return device_shape;
}
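// Example for FRAC_NZ above: shape [2, 33, 65] with float16 (c0 = 16) keeps the leading
// batch dim and gives [2, ceil(65 / 16), ceil(33 / 16), 16, 16] = [2, 5, 3, 16, 16].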

ShapeVector DeviceShapeTransfer::FRAC_ZN_LSTMDeviceShape(const ShapeVector &shape, const TypeId &) {
  ShapeVector device_shape;
  const int64_t lstm_ni = 4;
  const int64_t ni = 16;
  int64_t first = abstract::Shape::kShapeDimAny;
  int64_t second = abstract::Shape::kShapeDimAny;
  if (!HasShapeDynamic({shape[kN], shape[kC]})) {
    const int64_t h = shape.at(kN) / lstm_ni;
    const int64_t i = shape.at(kC) - h;
    first = DivCeil(i, ni) + DivCeil(h, ni);
    second = lstm_ni * DivCeil(h, ni);
  }
  device_shape.push_back(first);
  device_shape.push_back(second);
  device_shape.push_back(ni);
  device_shape.push_back(ni);
  return device_shape;
}
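// Example for FRACTAL_ZN_LSTM above: with shape[kN] = 256 and shape[kC] = 96 we get
// h = 64 and i = 32, so the device shape is [2 + 4, 4 * 4, 16, 16] = [6, 16, 16, 16].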

ShapeVector DeviceShapeTransfer::FRAC_ZDeviceShapeWithGroups(const ShapeVector &shape, const TypeId &type,
                                                             int64_t groups) {
  if (!CheckDims(shape)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Check dims failed.";
  }
  if (groups <= 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "The value of groups should be greater than 0, but got " << groups;
  }
  auto cube_size = GetCubeSizeByType(type);
  auto c1_dim = abstract::Shape::kShapeDimAny;
  auto g_dim = abstract::Shape::kShapeDimAny;
  auto n1 = abstract::Shape::kShapeDimAny;
  if (shape.size() < kShape2dDims) {
    MS_LOG(INTERNAL_EXCEPTION) << "Format FRAC_Z with groups doesn't support shape with " << shape.size() << " dims";
  }
  if (!HasShapeDynamic({shape[kC], shape[kN]})) {
    auto group_size = groups;
    auto cin_ori_tmp = static_cast<int64_t>(shape[kC]);
    auto cout_ori_tmp = static_cast<int64_t>(shape[kN]) / group_size;
    auto e_mult =
      std::min(Lcm(Lcm(cin_ori_tmp, cube_size) / cin_ori_tmp, Lcm(cout_ori_tmp, cube_size) / cout_ori_tmp), group_size);
    auto cin_opt = DivCeil(e_mult * cin_ori_tmp, cube_size) * cube_size;
    c1_dim = cin_opt / cube_size;
    g_dim = DivCeil(group_size, e_mult);
    n1 = DivCeil(cout_ori_tmp * e_mult, cube_size);
  }
  ShapeVector device_shape;
  if (!HasShapeDynamic({shape[kC], shape[kN], shape[kH], shape[kW]})) {
    device_shape.push_back(g_dim * c1_dim * shape[kH] * shape[kW]);
  } else {
    device_shape.push_back(abstract::Shape::kShapeDimAny);
  }
  device_shape.push_back(n1);
  device_shape.push_back(kNiSize);
  device_shape.push_back(cube_size);
  return device_shape;
}
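// Example for FRAC_Z with groups above: NCHW [64, 16, 3, 3], float16 and groups = 2
// give e_mult = 1, c1_dim = 1, g_dim = 2 and n1 = 2, so the device shape is
// [2 * 1 * 3 * 3, 2, 16, 16] = [18, 2, 16, 16].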

ShapeVector DeviceShapeTransfer::FRAC_ZN_RNNDeviceShape(const ShapeVector &shape, const TypeId &type,
                                                        const ShapeVector &input_hidden_size) {
  if (shape.size() < kShape2dDims) {
    MS_LOG(INTERNAL_EXCEPTION) << "Format FRACTAL_ZN_RNN doesn't support shape with " << shape.size() << " dims";
  }
  auto C0 = GetCubeSizeByType(type);
  auto input_size = input_hidden_size[0];
  auto hidden_size = input_hidden_size[1];
  auto dim_last1 = shape[shape.size() - 1];
  auto dim_last2 = shape[shape.size() - kDim2];
  const int64_t NUM16 = 16;

  ShapeVector device_shape = shape;
  if (dim_last2 == abstract::Shape::kShapeDimAny) {
    device_shape[shape.size() - kDim2] = abstract::Shape::kShapeDimAny;
  } else if (dim_last2 == input_size || dim_last2 == hidden_size) {
    device_shape[shape.size() - kDim2] = DivCeil(dim_last2, NUM16);
  } else if (dim_last2 == input_size + hidden_size) {
    device_shape[shape.size() - kDim2] = DivCeil(input_size, NUM16) + DivCeil(hidden_size, NUM16);
  } else {
    MS_LOG(INTERNAL_EXCEPTION) << "The second-last dim value of shape is invalid.";
  }
  if (dim_last1 == abstract::Shape::kShapeDimAny) {
    device_shape[shape.size() - kDim1] = abstract::Shape::kShapeDimAny;
  } else {
    if (dim_last1 % hidden_size != 0) {
      MS_LOG(INTERNAL_EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size "
                                 << hidden_size;
    }
    int64_t n_num = shape[shape.size() - 1] / hidden_size;
    device_shape[shape.size() - kDim1] = n_num * DivCeil(hidden_size, C0);
  }
  device_shape.push_back(NUM16);
  device_shape.push_back(C0);
  return device_shape;
}
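// Example for FRACTAL_ZN_RNN above: shape [114, 200] with input_size = 64,
// hidden_size = 50 and float16 (C0 = 16) gives [4 + 4, 4 * 4, 16, 16] = [8, 16, 16, 16].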

ShapeVector DeviceShapeTransfer::NDRNNBiasDeviceShape(const ShapeVector &shape, const TypeId &type,
                                                      int64_t hidden_size) {
  if (shape.empty()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Format ND_RNN_BIAS doesn't support an empty shape.";
  }
  auto C0 = GetCubeSizeByType(type);
  ShapeVector device_shape = shape;
  // cppcheck-suppress *
  auto dim_last1 = shape[shape.size() - 1];
  if (dim_last1 == abstract::Shape::kShapeDimAny) {
    device_shape[shape.size() - 1] = abstract::Shape::kShapeDimAny;
  } else {
    if (hidden_size <= 0 || dim_last1 % hidden_size != 0) {
      MS_LOG(INTERNAL_EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size "
                                 << hidden_size;
    }
    int64_t n_num = shape[shape.size() - 1] / hidden_size;
    device_shape[shape.size() - 1] = n_num * DivCeil(hidden_size, C0) * C0;
  }
  return device_shape;
}
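// Example for ND_RNN_BIAS above: shape [200] with hidden_size = 50 and float16
// (C0 = 16) gives 4 * ceil(50 / 16) * 16 = 256, i.e. a device shape of [256].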

ShapeVector DeviceShapeTransfer::GetAttrInputAndHiddenSize(const AnfNodePtr &node) const {
  MS_EXCEPTION_IF_NULL(node);
  std::vector<int64_t> input_hidden_size = {kAlign16, kAlign16};
  if (!node->isa<CNode>() && !node->isa<Parameter>()) {
    return input_hidden_size;
  }

  if (node->isa<Parameter>()) {
    auto param = node->cast<ParameterPtr>();
    input_hidden_size[0] = param->input_size();
    input_hidden_size[1] = param->hidden_size();
  } else {
    CNodePtr cnode = node->cast<CNodePtr>();
    if (cnode == nullptr || !common::AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode) ||
        !common::AnfAlgo::HasNodeAttr(kAttrInputSize, cnode)) {
      MS_LOG(INTERNAL_EXCEPTION)
        << "Node with format FRACTAL_ZN_RNN or ND_RNN_BIAS should have hidden_size or input_size attr. Node info:"
        << node->DebugString();
    }
    input_hidden_size[0] = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrInputSize);
    input_hidden_size[1] = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrHiddenSize);
  }
  return input_hidden_size;
}

/**###################### DATA FORMAT TRANS ################################*/
inline void SetData(int64_t size, bool pad_zero, int64_t src_idx, int64_t dst_idx, const FormatArgs &args,
                    void *result) {
  switch (size) {
    case b1:
      static_cast<uint8_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint8_t *>(args.data)[src_idx];
      break;
    case b2:
      static_cast<uint16_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint16_t *>(args.data)[src_idx];
      break;
    case b4:
      static_cast<uint32_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint32_t *>(args.data)[src_idx];
      break;
    case b8:
      static_cast<uint64_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint64_t *>(args.data)[src_idx];
      break;
    default:
      MS_LOG(EXCEPTION) << "Trans data not support size " << size;
  }
}

bool FormatTransfer::TransDataByFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, size_t index,
                                       bool is_forward) {
  int64_t groups = 1;
  if (args.device_format == kOpFormat_FRAC_Z && node != nullptr) {
    groups = common::AnfAlgo::GetAttrGroups(node, index);
  }
  if (is_forward) {
    return TransDataForwardCore(args, result, groups);
  }
  return TransDataBackwordCore(args, result, groups);
}

bool FormatTransfer::TransDataForwardCore(const FormatArgs &args, void *result, int64_t groups) {
  MS_LOG(DEBUG) << "Start trans format.";
  if (abstract::TypeIdSize(args.src_data_type) < 1) {
    MS_LOG(ERROR) << "Invalid datatype: " << args.src_data_type;
    return false;
  }
  if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) {
    return NCHW_TO_FRAC_Z_WITH_GROUPS(args, result, true, groups);
  }
  auto iter = format_trans_fp_map.find(args.device_format);
  if (iter == format_trans_fp_map.end()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Unexpected format[" << args.device_format << "]";
  }
  return iter->second(args, result);
}

bool FormatTransfer::TransDataBackwordCore(const FormatArgs &args, void *result, int64_t groups) {
  MS_LOG(DEBUG) << "Start trans format.";
  if (abstract::TypeIdSize(args.src_data_type) < 1) {
    MS_LOG(ERROR) << "Invalid datatype, type: " << args.src_data_type;
    return false;
  }
  if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) {
    return FRAC_Z_TO_NCHW_WITH_GROUPS(args, result, groups);
  }
  auto iter = format_trans_bp_map.find(args.device_format);
  if (iter == format_trans_bp_map.end()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Unexpected format[" << args.device_format << "]";
  }
  return iter->second(args, result);
}

bool FormatTransfer::CheckArgs(const FormatArgs &args, int64_t *size) {
  if (args.host_shape.size() != kDim4) {
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
  MS_EXCEPTION_IF_NULL(size);
  *size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  if (*size < 1) {
    MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
    return false;
  }
  auto total_size = abstract::ShapeSize(args.device_shape) * (*size);
  if (total_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  return true;
}

bool FormatTransfer::TransShapeToHW_NZ(const ShapeVector &host_shape, ShapeVector *hw_shape) {
  MS_EXCEPTION_IF_NULL(hw_shape);
  if (host_shape.empty()) {
    MS_LOG(ERROR) << "Size of vector is 0.";
    return false;
  }
  switch (host_shape.size()) {
    case 1:
      hw_shape->push_back(1);
      hw_shape->push_back(1);
      hw_shape->push_back(host_shape[0]);
      return true;
    default:
      auto size = host_shape.size();
      if (size < kDim2) {
        MS_LOG(ERROR) << "Illegal size: " << size;
        return false;
      }
      int64_t times = 1;
      for (size_t i = 0; i != size - kDim2; i++) {
        times *= host_shape[i];
      }
      hw_shape->push_back(times);
      hw_shape->push_back(host_shape[size - kDim2]);
      hw_shape->push_back(host_shape[size - kDim1]);
      return true;
  }
}
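// Example for TransShapeToHW_NZ above: host shape [2, 3, 16, 24] collapses the leading
// dims into times = 6 and yields hw_shape [6, 16, 24]; a 1-D shape [5] yields [1, 1, 5].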

bool FormatTransfer::NCHW_TO_4D(const FormatArgs &args, void *result) {
  // trans nchw to NHWC or HWCN
  MS_LOG(DEBUG) << "Trans format from nchw to " << args.device_format;
  MS_EXCEPTION_IF_NULL(result);
  int64_t size = 0;
  if (!CheckArgs(args, &size)) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  for (int64_t ni = 0; ni < n; ni++) {
    for (int64_t ci = 0; ci < c; ci++) {
      for (int64_t hi = 0; hi < h; hi++) {
        for (int64_t wi = 0; wi < w; wi++) {
          auto src_idx = ni * c * h * w + ci * h * w + hi * w + wi;
          int64_t dst_idx = 0;
          if (args.device_format == kOpFormat_NHWC) {
            dst_idx = ni * h * w * c + hi * w * c + wi * c + ci;
          } else if (args.device_format == kOpFormat_HWCN) {
            dst_idx = hi * w * c * n + wi * c * n + ci * n + ni;
          }
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
    }
  }
  return true;
}

bool FormatTransfer::TO_NCHW(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format to nchw from " << args.device_format;
  MS_EXCEPTION_IF_NULL(result);
  int64_t size = 0;
  if (!CheckArgs(args, &size)) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  for (int64_t ni = 0; ni < n; ni++) {
    for (int64_t ci = 0; ci < c; ci++) {
      for (int64_t hi = 0; hi < h; hi++) {
        for (int64_t wi = 0; wi < w; wi++) {
          auto dst_idx = ni * c * h * w + ci * h * w + hi * w + wi;
          int64_t src_idx = 0;
          if (args.device_format == kOpFormat_NHWC) {
            src_idx = ni * h * w * c + hi * w * c + wi * c + ci;
          } else if (args.device_format == kOpFormat_HWCN) {
            src_idx = hi * w * c * n + wi * c * n + ci * n + ni;
          }
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
    }
  }
  return true;
}

bool FormatTransfer::NCHW_TO_FRAC_Z(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from nchw to frac_z";
  MS_EXCEPTION_IF_NULL(result);
  auto size = Common4DCheck(args);
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  auto c0 = GetCubeSizeByType(args.src_data_type);
  auto c1 = DivCeil(c, c0);
  auto hw = h * w;
  auto chw = c * hw;
  auto hwc0 = hw * c0;
  auto nchw = n * chw;

  auto hf_cnt = DivCeil(n, kNiSize);
  auto vf_cnt = c1 * hw;
  auto fractal_ele_cnt = c0 * kNiSize;
  auto total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
  auto dst_size = total_ele_cnt * size;
  if (dst_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, "
                  << "dst size is :" << dst_size << ", device size is :" << args.device_size;
    return false;
  }

  for (int64_t vfi = 0; vfi < vf_cnt; vfi++) {
    auto vf_base_i = vfi * hf_cnt;  // vertical fractal matrix base index
    for (int64_t hfi = 0; hfi < hf_cnt; hfi++) {
      auto gfi = vf_base_i + hfi;  // global fractal matrix index
      auto src_n_offset = hfi * chw * kNiSize;
      auto src_f_offset = src_n_offset + vfi % hw + vfi / hw * hwc0;
      for (int64_t row = 0; row < c0; row++) {
        auto src_ci = vfi / hw * c0 + row;
        auto src_row_offset = src_f_offset + row * hw;
        for (int64_t col = 0; col < kNiSize; col++) {
          auto src_ni = hfi * kNiSize + col;
          auto src_idx = src_row_offset + chw * col;
          auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row;
          auto pad_zero = src_ni >= n || src_idx >= nchw || src_ci >= c;
          SetData(size, pad_zero, src_idx, dst_idx, args, result);
        }
      }
    }
  }
  return true;
}

bool FormatTransfer::NCHW_TO_FRAC_NZ(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from nchw to frac_nz.";
  MS_EXCEPTION_IF_NULL(result);
  ShapeVector hw_shape;
  if (!TransShapeToHW_NZ(args.host_shape, &hw_shape)) {
    MS_LOG(ERROR) << "Trans shape failed.";
    return false;
  }
  if (hw_shape.size() < kDim3 || args.device_shape.size() < kDim4) {
    MS_LOG(ERROR) << "Invalid shape size.";
    return false;
  }
  auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
    return false;
  }

  auto dst_size = abstract::ShapeSize(args.device_shape) * size;
  if (dst_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
    return false;
  }
  auto times = hw_shape.at(0);
  auto h = hw_shape.at(hw_h);
  auto w = hw_shape.at(hw_w);
  auto hw = h * w;

  auto shape_size = args.device_shape.size();
  auto w1 = args.device_shape[shape_size - fnz_w1];
  auto h1 = args.device_shape[shape_size - fnz_h1];
  auto h0 = args.device_shape[shape_size - fnz_h0];
  auto w0 = args.device_shape[shape_size - fnz_w0];
  auto h1h0w0 = h1 * h0 * w0;
  auto w1h1h0w0 = w1 * h1h0w0;
  auto num_w1 = w / w0;

  for (int64_t times_idx = 0; times_idx < times; times_idx++) {
    auto times_head = times_idx * w1h1h0w0;
    auto src_times_head = times_idx * hw;
    for (int64_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
      auto h1h0_head = times_head + h1h0_idx * w0;
      auto src_h_head = src_times_head + h1h0_idx * w;
      for (int64_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
        for (int64_t i = 0; i < w0; ++i) {
          int64_t src_idx = src_h_head + w1_idx * w0 + i;
          int64_t dst_idx = h1h0_head + w1_idx * h1h0w0 + i;
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
      auto w1_head = num_w1 * w0;
      for (int64_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
        auto src_w_idx = w1_head + w0_idx;
        int64_t dst_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
        int64_t src_idx = src_h_head + src_w_idx;
        SetData(size, false, src_idx, dst_idx, args, result);
      }
    }
  }
  return true;
}

bool FormatTransfer::NCHW_TO_FRAC_ZC04(const FormatArgs &args, void *result) {
  // trans nchw to FracZc04
  MS_LOG(DEBUG) << "Trans format from nchw to FracZc04.";
  MS_EXCEPTION_IF_NULL(result);
  int64_t size = 0;
  if (!CheckArgs(args, &size)) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
  auto cube = GetCubeSizeByType(args.src_data_type);
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  const int64_t c0 = 4;
  auto c1 = DivCeil(c, c0);
  auto hwc0 = h * w * c0;
  auto hwc = h * w * c;
  auto nhwc = n * h * w * c;
  auto n_cnt = DivCeil(n, kNiSize);
  auto v_cnt = DivCeil(h * w * c0 * c1, cube);
  int64_t dst_idx = 0;

  for (int64_t vi = 0; vi < v_cnt; vi++) {
    for (int64_t ni = 0; ni < n_cnt; ni++) {
      for (int64_t col = 0; col < kNiSize; col++) {
        for (int64_t row = 0; row < kNiSize; row++) {
          int64_t cur_cube_n = kNiSize * ni + col;
          int64_t cur_cube_c1hwc0 = kNiSize * vi + row;
          auto desc_g = cur_cube_n / n;
          auto desc_n = cur_cube_n % n;
          auto desc_c1 = cur_cube_c1hwc0 / hwc0;
          auto desc_c0 = cur_cube_c1hwc0 % c0;
          auto desc_h = (cur_cube_c1hwc0 - hwc0 * desc_c1) / (w * c0);
          auto desc_w = (cur_cube_c1hwc0 - hwc0 * desc_c1 - w * c0 * desc_h) / c0;
          auto c_idx = desc_c1 * c0 + desc_c0;
          auto src_idx = desc_g * nhwc + desc_n * hwc + c_idx * h * w + desc_h * w + desc_w;
          auto pad_zero = desc_g >= 1 || desc_n >= n || c_idx >= c;
          SetData(size, pad_zero, src_idx, dst_idx, args, result);
          dst_idx++;
        }
      }
    }
  }
  return true;
}

bool FormatTransfer::NCHW_TO_NC1HWC0(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from nchw to NC1HWC0";
  MS_EXCEPTION_IF_NULL(result);
  auto size = Common4DCheck(args);
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }

  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  auto c0 = GetCubeSizeByType(args.src_data_type);
  if (args.device_format == kOpFormat_NC1HWC0_C04) {
    c0 = kCubeSize_C04;
  }
  auto c1 = DivCeil(c, c0);
  auto hw = h * w;
  auto chw = c * hw;
  auto c1hwc0 = c1 * hw * c0;
  auto wc0 = w * c0;

  for (int64_t n_idx = 0; n_idx < n; n_idx++) {
    int64_t n_head_addr = n_idx * c1hwc0;
    for (int64_t c1_idx = 0; c1_idx < c1; c1_idx++) {
      int64_t c1_head_addr = n_head_addr + c1_idx * hw * c0;
      for (int64_t h_idx = 0; h_idx < h; h_idx++) {
        int64_t h_head_addr = c1_head_addr + h_idx * wc0;
        for (int64_t w_idx = 0; w_idx < w; w_idx++) {
          int64_t w_head_addr = h_head_addr + w_idx * c0;
          for (int64_t c0_idx = 0; c0_idx < c0; c0_idx++) {
            int64_t dst_idx = c0_idx + w_head_addr;
            int64_t c_idx = c0_idx + c1_idx * c0;
            int64_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx;
            auto pad_zero = c_idx >= c;
            SetData(size, pad_zero, src_idx, dst_idx, args, result);
          }
        }
      }
    }
  }
  return true;
}
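// Index mapping used above: with C1 = ceil(C / c0), element (n, c, h, w) of the NCHW
// input is written to n * C1 * H * W * c0 + (c / c0) * H * W * c0 + h * W * c0 + w * c0 + (c % c0);
// destination positions whose channel index is >= C are zero-padded.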
1252 
NCHW_TO_NC1HWC04(const FormatArgs & args,void * result)1253 bool FormatTransfer::NCHW_TO_NC1HWC04(const FormatArgs &args, void *result) {
1254   MS_LOG(DEBUG) << "Trans format from nchw to Nc1hwc04.";
1255   return NCHW_TO_NC1HWC0(args, result);
1256 }
1257 
NCHW_TO_C1HWNCOC0(const FormatArgs & args,void * result)1258 bool FormatTransfer::NCHW_TO_C1HWNCOC0(const FormatArgs &args, void *result) {
1259   // trans nchw to c1hwncoc0
1260   MS_LOG(DEBUG) << "Trans format from nchw to c1hwncoc0.";
1261   MS_EXCEPTION_IF_NULL(result);
1262   int64_t size = 0;
1263   if (!CheckArgs(args, &size)) {
1264     MS_LOG(ERROR) << "Check args failed.";
1265     return false;
1266   }
1267   auto n = args.host_shape[kN];
1268   auto c = args.host_shape[kC];
1269   auto h = args.host_shape[kH];
1270   auto w = args.host_shape[kW];
1271   const int co_idx = 4;
1272   const int c0_idx = 5;
1273   auto c1 = args.device_shape[0];
1274   auto co = args.device_shape[co_idx];
1275   auto c0 = args.device_shape[c0_idx];
1276 
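  // Destination layout is [C1, H, W, N, Co, C0]. Real data is written only at the
  // "diagonal" positions where co_i == c0_i and c_i = c1_i * c0 + c0_i < c; every
  // other position is zero-filled (pad_zero below), so each Co x C0 tile is an
  // identity-like placement of the channel values.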
1277   for (int64_t c1_i = 0; c1_i < c1; c1_i++) {
1278     for (int64_t h_i = 0; h_i < h; h_i++) {
1279       for (int64_t w_i = 0; w_i < w; w_i++) {
1280         for (int64_t n_i = 0; n_i < n; n_i++) {
1281           for (int64_t co_i = 0; co_i < co; co_i++) {
1282             for (int64_t c0_i = 0; c0_i < c0; c0_i++) {
1283               int64_t dst_idx = c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 +
1284                                 co_i * c0 + c0_i;
1285               int64_t c_i = c0_i + c1_i * c0;
1286               int64_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
1287               auto pad_zero = !(c_i < c && c0_i == co_i);
1288               SetData(size, pad_zero, src_idx, dst_idx, args, result);
1289             }
1290           }
1291         }
1292       }
1293     }
1294   }
1295   return true;
1296 }
1297 
1298 bool FormatTransfer::NCDHW_TO_NDC1HWC0(const FormatArgs &args, void *result) {
1299   MS_LOG(DEBUG) << "Trans from ncdhw to ndc1hwc0";
1300   MS_EXCEPTION_IF_NULL(result);
1301 
1302   if (args.host_shape.size() != kDim5) {
1303     MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
1304     return false;
1305   }
1306   auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
1307   if (size < 1) {
1308     MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
1309     return false;
1310   }
1311   auto total_size = abstract::ShapeSize(args.device_shape) * size;
1312   if (total_size != SizeToLong(args.device_size)) {
1313     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1314     return false;
1315   }
1316 
1317   auto n = args.host_shape[N_ncdhw];
1318   auto c = args.host_shape[C_ncdhw];
1319   auto d = args.host_shape[D_ncdhw];
1320   auto h = args.host_shape[H_ncdhw];
1321   auto w = args.host_shape[W_ncdhw];
1322   auto c0 = GetCubeSizeByType(args.src_data_type);
1323   auto c1 = DivCeil(c, c0);
1324   const int64_t cdhw = c * d * h * w;
1325   const int64_t dhw = d * h * w;
1326   const int64_t hw = h * w;
1327   const int64_t dc1hwc0 = d * c1 * h * w * c0;
1328   const int64_t c1hwc0 = c1 * h * w * c0;
1329   const int64_t hwc0 = h * w * c0;
1330   const int64_t wc0 = w * c0;
1331 
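  // Destination layout is [N, D, C1, H, W, C0] with C1 = ceil(C / C0); the stride
  // products above (dc1hwc0, c1hwc0, hwc0, wc0) are the per-dimension offsets of
  // that layout, and channel slots with c_i >= c are zero-padded.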
1332   for (int64_t n_i = 0; n_i < n; n_i++) {
1333     int64_t n_head = n_i * dc1hwc0;
1334     for (int64_t d_i = 0; d_i < d; d_i++) {
1335       int64_t d_head = n_head + d_i * c1hwc0;
1336       for (int64_t c1_i = 0; c1_i < c1; c1_i++) {
1337         int64_t c1_head = d_head + c1_i * hwc0;
1338         for (int64_t h_i = 0; h_i < h; h_i++) {
1339           int64_t h_head = c1_head + h_i * wc0;
1340           for (int64_t w_i = 0; w_i < w; w_i++) {
1341             int64_t w_head = h_head + w_i * c0;
1342             for (int64_t c0_i = 0; c0_i < c0; c0_i++) {
1343               int64_t dst_i = c0_i + w_head;
1344               int64_t c_i = c0_i + c1_i * c0;
1345               int64_t src_i = n_i * cdhw + c_i * dhw + d_i * hw + h_i * w + w_i;
1346               auto pad_zero = c_i >= c;
1347               SetData(size, pad_zero, src_i, dst_i, args, result);
1348             }
1349           }
1350         }
1351       }
1352     }
1353   }
1354   return true;
1355 }
1356 
1357 bool FormatTransfer::NCDHW_TO_FRAC_Z3D(const FormatArgs &args, void *result) {
1358   MS_LOG(DEBUG) << "Trans from ncdhw to frac_z_3d";
1359   MS_EXCEPTION_IF_NULL(result);
1360 
1361   if (args.host_shape.size() != kDim5) {
1362     MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
1363     return false;
1364   }
1365   auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
1366   if (size < 1) {
1367     MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
1368     return false;
1369   }
1370   auto total_size = abstract::ShapeSize(args.device_shape) * size;
1371   if (total_size != SizeToLong(args.device_size)) {
1372     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1373     return false;
1374   }
1375 
1376   auto n = args.host_shape[N_ncdhw];
1377   auto c = args.host_shape[C_ncdhw];
1378   auto d = args.host_shape[D_ncdhw];
1379   auto h = args.host_shape[H_ncdhw];
1380   auto w = args.host_shape[W_ncdhw];
1381 
1382   auto n1n0 = DivCeil(n, kNiSize) * kNiSize;
1383   auto c0 = GetCubeSizeByType(args.src_data_type);
1384   auto c1 = DivCeil(c, c0);
1385   auto hw = h * w;
1386   auto dhw = d * hw;
1387   auto cdhw = c * dhw;
1388   auto n1n0c0 = n1n0 * c0;
1389   auto wn1n0c0 = w * n1n0c0;
1390   auto hwn1n0c0 = h * wn1n0c0;
1391   auto c1hwn1n0c0 = c1 * hwn1n0c0;
1392 
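  // FRACTAL_Z_3D arranges the data as [D, C1, H, W, N1*N0, C0]: N is rounded up to a
  // multiple of the 16-element Ni cube (n1n0) and C to a multiple of c0, and any index
  // that falls in the padded N or C region is written as zero.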
1393   for (int64_t d_i = 0; d_i < d; d_i++) {
1394     for (int64_t c1_i = 0; c1_i < c1; c1_i++) {
1395       for (int64_t h_i = 0; h_i < h; h_i++) {
1396         for (int64_t w_i = 0; w_i < w; w_i++) {
1397           for (int64_t n1n0_i = 0; n1n0_i < n1n0; n1n0_i++) {
1398             for (int64_t c0_i = 0; c0_i < c0; c0_i++) {
1399               auto dst_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i;
1400               // ncdhw
1401               int64_t src_i = n1n0_i * cdhw + (c1_i * c0 + c0_i) * dhw + d_i * hw + h_i * w + w_i;
1402               auto pad_zero = ((c1_i * c0 + c0_i) >= c) || (n1n0_i >= n);
1403               SetData(size, pad_zero, src_i, dst_i, args, result);
1404             }
1405           }
1406         }
1407       }
1408     }
1409   }
1410   return true;
1411 }
1412 
1413 bool FormatTransfer::NCHW_TO_FRAC_Z_WITH_GROUPS(const FormatArgs &args, void *result, bool to_device, int64_t groups) {
1414   MS_EXCEPTION_IF_NULL(result);
1415   auto size = Common4DCheck(args);
1416   auto n_dim = args.host_shape[kN];
1417   auto c_dim = args.host_shape[kC];
1418   auto h_dim = args.host_shape[kH];
1419   auto w_dim = args.host_shape[kW];
1420   auto d_dim = 1;
1421   auto cin_ori = c_dim;
1422   if (groups <= 0) {
1423     MS_LOG(INTERNAL_EXCEPTION) << "The value of groups should be greater than 0, but got " << groups;
1424   }
1425   // cppcheck-suppress *
1426   auto cout_ori = n_dim / groups;
1427   if (cin_ori == 0 || cout_ori == 0) {
1428     MS_LOG(ERROR) << "cin_ori and cout_ori must not be 0.";
1429     return false;
1430   }
1431   auto cube_k = GetCubeSizeByType(args.src_data_type);
1432   auto e_mult = std::min(Lcm(Lcm(cin_ori, cube_k) / cin_ori, Lcm(cout_ori, kCubeSize) / cout_ori), groups);
1433   if (e_mult == 0) {
1434     MS_LOG(INTERNAL_EXCEPTION) << "The value of e_mult should be greater than 0, but got " << e_mult;
1435   }
1436   auto cin_opt = DivCeil(e_mult * cin_ori, cube_k) * cube_k;
1437   auto cout_opt = DivCeil(e_mult * cout_ori, kCubeSize) * kCubeSize;
1438   // cppcheck-suppress *
1439   auto c1_dim = cin_opt / cube_k;
1440   auto dst_size =
1441     to_device ? abstract::ShapeSize(args.device_shape) * size : abstract::ShapeSize(args.host_shape) * size;
1442   if (dst_size == 0) {
1443     return true;
1444   }
1445   auto ret = memset_s(result, LongToSize(dst_size), 0, LongToSize(dst_size));
1446   if (ret != EOK) {
1447     MS_LOG(ERROR) << "memset_s failed.";
1448     return false;
1449   }
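  // Grouped FRACTAL_Z packs e_mult consecutive groups into one fractal block so that the
  // packed input channels (e_mult * cin_ori, rounded up to cube_k) and packed output
  // channels (e_mult * cout_ori, rounded up to kCubeSize) fill whole cubes; e_mult is
  // capped at groups. The buffer is zero-initialized above because only the valid
  // (g, c, n) positions are written inside the loops, and everything else stays padding.
  // For example, with groups = 4, cin_ori = 8, cout_ori = 8 and cube_k = 16,
  // e_mult = min(lcm(2, 2), 4) = 2, so two groups share one 16 x 16 cube block.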
1450   for (int64_t g = 0; g < groups; ++g) {
1451     for (int64_t d = 0; d < d_dim; ++d) {
1452       for (int64_t c = 0; c < c_dim; ++c) {
1453         for (int64_t h = 0; h < h_dim; ++h) {
1454           for (int64_t w = 0; w < w_dim; ++w) {
1455             for (int64_t n = 0; n < cout_ori; ++n) {
1456               int64_t e_val = g % e_mult;
1457               int64_t dst_ci = e_val * cin_ori + c;
1458               int64_t dst_co = e_val * cout_ori + n;
1459               int64_t src_co = g * cout_ori + n;
1460               int64_t temporary = dst_ci % cube_k;
1461               int64_t dev_idx = (g / e_mult) * d_dim * c1_dim * h_dim * w_dim * cout_opt * cube_k +
1462                                 d * c1_dim * h_dim * w_dim * cout_opt * cube_k +
1463                                 (dst_ci / cube_k) * h_dim * w_dim * cout_opt * cube_k + h * w_dim * cout_opt * cube_k +
1464                                 w * cout_opt * cube_k + dst_co * cube_k + temporary;
1465               int64_t hst_idx =
1466                 src_co * c_dim * d_dim * h_dim * w_dim + c * d_dim * h_dim * w_dim + d * h_dim * w_dim + h * w_dim + w;
1467               if (to_device) {
1468                 SetData(size, false, hst_idx, dev_idx, args, result);
1469               } else {
1470                 SetData(size, false, dev_idx, hst_idx, args, result);
1471               }
1472             }
1473           }
1474         }
1475       }
1476     }
1477   }
1478   return true;
1479 }
1480 
1481 bool FormatTransfer::NC1HWC0_TO_NCHW(const FormatArgs &args, void *result) {
1482   MS_LOG(DEBUG) << "Trans format from nc1hwc0 to nchw";
1483   MS_EXCEPTION_IF_NULL(result);
1484   auto size = Common4DCheck(args);
1485   auto total_size = abstract::ShapeSize(args.device_shape) * size;
1486   if (total_size != SizeToLong(args.device_size)) {
1487     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1488     return false;
1489   }
1490 
1491   auto n = args.host_shape[kN];
1492   auto c = args.host_shape[kC];
1493   auto h = args.host_shape[kH];
1494   auto w = args.host_shape[kW];
1495   auto c1 = args.device_shape[kDim1];
1496   auto c0 = args.device_shape[kDim4];
1497 
1498   auto hw = h * w;
1499   auto chw = c * hw;
1500   auto wc0 = w * c0;
1501   auto hwc0 = h * wc0;
1502   auto c1hwc0 = c1 * hwc0;
1503 
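  // Reverse of NCHW_TO_NC1HWC0: every real channel c_idx maps back to
  // (c1_idx, c0_idx) = (c_idx / c0, c_idx % c0), so only valid device elements are
  // read and no pad handling is needed on this path.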
1504   for (int64_t n_idx = 0; n_idx < n; n_idx++) {
1505     int64_t n_head_addr = n_idx * chw;
1506     for (int64_t c_idx = 0; c_idx < c; c_idx++) {
1507       int64_t c_head_addr = n_head_addr + c_idx * hw;
1508       for (int64_t h_idx = 0; h_idx < h; h_idx++) {
1509         int64_t h_head_addr = c_head_addr + h_idx * w;
1510         for (int64_t w_idx = 0; w_idx < w; w_idx++) {
1511           int64_t dst_idx = h_head_addr + w_idx;
1512           int64_t c1_idx = c_idx / c0;
1513           int64_t c0_idx = c_idx % c0;
1514           int64_t src_idx = n_idx * c1hwc0 + c1_idx * hwc0 + h_idx * wc0 + w_idx * c0 + c0_idx;
1515           SetData(size, false, src_idx, dst_idx, args, result);
1516         }
1517       }
1518     }
1519   }
1520   return true;
1521 }
1522 
1523 bool FormatTransfer::NC1HWC04_TO_NCHW(const FormatArgs &args, void *result) {
1524   MS_LOG(DEBUG) << "Trans format from Nc1hwc04 to nchw.";
1525   return NC1HWC0_TO_NCHW(args, result);
1526 }
1527 
1528 bool FormatTransfer::C1HWNCOC0_TO_NCHW(const FormatArgs &args, void *result) {
1529   // trans c1hwncoc0 to nchw
1530   MS_LOG(DEBUG) << "Trans format from c1hwncoc0 to nchw";
1531   MS_EXCEPTION_IF_NULL(result);
1532   int64_t size = 0;
1533   if (!CheckArgs(args, &size)) {
1534     MS_LOG(ERROR) << "Check args failed.";
1535     return false;
1536   }
1537   auto n = args.host_shape[kN];
1538   auto c = args.host_shape[kC];
1539   auto h = args.host_shape[kH];
1540   auto w = args.host_shape[kW];
1541   const int co_idx = 4;
1542   const int c0_idx = 5;
1543   auto co = args.device_shape[co_idx];
1544   auto c0 = args.device_shape[c0_idx];
1545   auto cube_k = GetCubeSizeByType(args.src_data_type);
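  // Only the diagonal elements (co_i == c0_i) of each Co x C0 tile carry real data,
  // so the reverse transform reads the source with co_i = c0_i = c_i % cube_k.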
1546   for (int64_t n_i = 0; n_i < n; n_i++) {
1547     for (int64_t c_i = 0; c_i < c; c_i++) {
1548       for (int64_t h_i = 0; h_i < h; h_i++) {
1549         for (int64_t w_i = 0; w_i < w; w_i++) {
1550           int64_t dst_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
1551           int64_t c1_i = c_i / cube_k;
1552           int64_t c0_i = c_i % cube_k;
1553           int64_t co_i = c0_i;
1554           int64_t src_idx =
1555             c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 + co_i * c0 + c0_i;
1556           SetData(size, false, src_idx, dst_idx, args, result);
1557         }
1558       }
1559     }
1560   }
1561   return true;
1562 }
1563 
1564 bool FormatTransfer::FRAC_Z_TO_NCHW(const FormatArgs &args, void *result) {
1565   MS_LOG(DEBUG) << "Trans format from frac_z to nchw";
1566   MS_EXCEPTION_IF_NULL(result);
1567   auto size = Common4DCheck(args);
1568   auto total_size = abstract::ShapeSize(args.device_shape) * size;
1569   if (total_size != SizeToLong(args.device_size)) {
1570     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1571     return false;
1572   }
1573 
1574   auto n0 = args.device_shape.at(fz_n0);
1575   auto ni = args.device_shape.at(fz_ni);
1576   auto c0 = args.device_shape.at(fz_c0);
1577   auto n = args.host_shape[kN];
1578   auto c = args.host_shape[kC];
1579   auto h = args.host_shape[kH];
1580   auto w = args.host_shape[kW];
1581   auto nc = ni * n0;
1582   auto ncc0 = nc * c0;
1583   auto wncc0 = w * ncc0;
1584   auto hwncc0 = h * wncc0;
1585   auto hw = h * w;
1586   auto chw = c * hw;
1587 
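  // FRAC_Z places the (c1, h, w) cubes along the first device dimension; inside each
  // cube region there are nc = ni * n0 rows of c0 elements, so the reverse transform
  // reads src_idx = c1 * hwncc0 + h * wncc0 + w * ncc0 + n * c0 + (c % c0).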
1588   for (int64_t n_idx = 0; n_idx < n; n_idx++) {
1589     int64_t n_head_addr = n_idx * chw;
1590     for (int64_t c_idx = 0; c_idx < c; c_idx++) {
1591       int64_t c_head_addr = n_head_addr + c_idx * hw;
1592       for (int64_t h_idx = 0; h_idx < h; h_idx++) {
1593         int64_t h_head_addr = c_head_addr + h_idx * w;
1594         for (int64_t w_idx = 0; w_idx < w; w_idx++) {
1595           auto dst_idx = h_head_addr + w_idx;
1596           auto c1_idx = c_idx / c0;
1597           auto c0_idx = c_idx % c0;
1598           auto nc_idx = n_idx;
1599           auto src_idx = c1_idx * hwncc0 + h_idx * wncc0 + w_idx * ncc0 + nc_idx * c0 + c0_idx;
1600           SetData(size, false, src_idx, dst_idx, args, result);
1601         }
1602       }
1603     }
1604   }
1605   return true;
1606 }
1607 
1608 bool FormatTransfer::FRAC_NZ_TO_NCHW(const FormatArgs &args, void *result) {
1609   MS_LOG(DEBUG) << "Trans format from frac_nz to nchw";
1610   MS_EXCEPTION_IF_NULL(result);
1611   ShapeVector hw_shape;
1612   if (!TransShapeToHW_NZ(args.host_shape, &hw_shape)) {
1613     MS_LOG(ERROR) << "Trans shape failed.";
1614     return false;
1615   }
1616   if (hw_shape.size() < kDim3 || args.device_shape.size() < kDim4) {
1617     MS_LOG(ERROR) << "Invalid shape size.";
1618     return false;
1619   }
1620   auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
1621   if (size < 1) {
1622     MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
1623     return false;
1624   }
1625 
1626   auto dst_size = abstract::ShapeSize(args.device_shape) * size;
1627   if (dst_size != SizeToLong(args.device_size)) {
1628     MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
1629     return false;
1630   }
1631   auto times = hw_shape.at(0);
1632   auto h = hw_shape.at(hw_h);
1633   auto w = hw_shape.at(hw_w);
1634   auto hw = h * w;
1635 
1636   auto shape_size = args.device_shape.size();
1637   auto w1 = args.device_shape[shape_size - fnz_w1];
1638   auto h1 = args.device_shape[shape_size - fnz_h1];
1639   auto h0 = args.device_shape[shape_size - fnz_h0];
1640   auto w0 = args.device_shape[shape_size - fnz_w0];
1641   auto h1h0w0 = h1 * h0 * w0;
1642   auto w1h1h0w0 = w1 * h1h0w0;
1643   auto num_w1 = w / w0;
1644 
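  // FRAC_NZ tiles the trailing (h, w) matrix of every leading "times" slice into
  // (w1, h1, h0, w0) blocks. The first inner loop copies the full w0-wide columns
  // (num_w1 = w / w0 of them); the loop after it handles the tail columns when w is
  // not a multiple of w0.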
1645   for (int64_t times_idx = 0; times_idx < times; times_idx++) {
1646     auto times_head = times_idx * w1h1h0w0;
1647     auto src_times_head = times_idx * hw;
1648     for (int64_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
1649       auto h1h0_head = times_head + h1h0_idx * w0;
1650       auto src_h_head = src_times_head + h1h0_idx * w;
1651       for (int64_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
1652         for (int64_t i = 0; i < w0; ++i) {
1653           int64_t src_idx = h1h0_head + w1_idx * h1h0w0 + i;
1654           int64_t dst_idx = src_h_head + w1_idx * w0 + i;
1655           SetData(size, false, src_idx, dst_idx, args, result);
1656         }
1657       }
1658       auto w1_head = num_w1 * w0;
1659       for (int64_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
1660         auto src_w_idx = w1_head + w0_idx;
1661         int64_t src_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
1662         int64_t dst_idx = src_h_head + src_w_idx;
1663         SetData(size, false, src_idx, dst_idx, args, result);
1664       }
1665     }
1666   }
1667   return true;
1668 }
1669 
1670 bool FormatTransfer::FRAC_Z3D_TO_NCDHW(const FormatArgs &args, void *result) {
1671   MS_LOG(DEBUG) << "Trans from frac_z_3d to ncdhw";
1672   MS_EXCEPTION_IF_NULL(result);
1673 
1674   if (args.host_shape.size() != kDim5) {
1675     MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
1676     return false;
1677   }
1678   auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
1679   if (size < 1) {
1680     MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
1681     return false;
1682   }
1683   auto total_size = abstract::ShapeSize(args.device_shape) * size;
1684   if (total_size != SizeToLong(args.device_size)) {
1685     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1686     return false;
1687   }
1688   auto n = args.host_shape[N_ncdhw];
1689   auto c = args.host_shape[C_ncdhw];
1690   auto d = args.host_shape[D_ncdhw];
1691   auto h = args.host_shape[H_ncdhw];
1692   auto w = args.host_shape[W_ncdhw];
1693   const int kFZ3D_C0 = 3;
1694   auto c0 = args.device_shape[kFZ3D_C0];
1695   auto cube_k = GetCubeSizeByType(args.src_data_type);
1696   auto c1 = DivCeil(c, cube_k);
1697   auto n1n0 = DivCeil(n, kNiSize) * kNiSize;
1698   auto n1n0c0 = n1n0 * c0;
1699   auto wn1n0c0 = w * n1n0c0;
1700   auto hwn1n0c0 = h * wn1n0c0;
1701   auto c1hwn1n0c0 = c1 * hwn1n0c0;
1702   auto hw = h * w;
1703   auto dhw = d * hw;
1704   auto cdhw = c * dhw;
1705 
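  // Reverse of NCDHW_TO_FRAC_Z3D: the device offset is rebuilt from (d, c1, h, w, n, c0)
  // with c1_i = c_i / c0 and c0_i = c_i % c0, and only the valid (n, c) positions are
  // read back, so no pad flag is required here.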
1706   for (int64_t n_i = 0; n_i < n; n_i++) {
1707     int64_t n_head = n_i * cdhw;
1708     for (int64_t c_i = 0; c_i < c; c_i++) {
1709       int64_t c_head = n_head + c_i * dhw;
1710       for (int64_t d_i = 0; d_i < d; d_i++) {
1711         int64_t d_head = c_head + d_i * hw;
1712         for (int64_t h_i = 0; h_i < h; h_i++) {
1713           int64_t h_head = d_head + h_i * w;
1714           for (int64_t w_i = 0; w_i < w; w_i++) {
1715             int64_t dst_i = h_head + w_i;
1716             int64_t c1_i = c_i / c0;
1717             int64_t c0_i = c_i % c0;
1718             int64_t nc_i = n_i;
1719             int64_t src_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + nc_i * c0 + c0_i;
1720             SetData(size, false, src_i, dst_i, args, result);
1721           }
1722         }
1723       }
1724     }
1725   }
1726   return true;
1727 }
1728 
1729 bool FormatTransfer::NDC1HWC0_TO_NCDHW(const FormatArgs &args, void *result) {
1730   MS_LOG(DEBUG) << "Trans from ndc1hwc0 to ncdhw";
1731   MS_EXCEPTION_IF_NULL(result);
1732 
1733   if (args.host_shape.size() != kDim5) {
1734     MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
1735     return false;
1736   }
1737   auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
1738   if (size < 1) {
1739     MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
1740     return false;
1741   }
1742   auto total_size = abstract::ShapeSize(args.device_shape) * size;
1743   if (total_size != SizeToLong(args.device_size)) {
1744     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1745     return false;
1746   }
1747   auto n = args.host_shape[N_ncdhw];
1748   auto c = args.host_shape[C_ncdhw];
1749   auto d = args.host_shape[D_ncdhw];
1750   auto h = args.host_shape[H_ncdhw];
1751   auto w = args.host_shape[W_ncdhw];
1752   auto c1 = args.device_shape[C1_ndc1hwc0];
1753   auto c0 = args.device_shape[C0_ndc1hwc0];
1754   const int64_t cdhw = c * d * h * w;
1755   const int64_t dhw = d * h * w;
1756   const int64_t hw = h * w;
1757   const int64_t dc1hwc0 = d * c1 * h * w * c0;
1758   const int64_t c1hwc0 = c1 * h * w * c0;
1759   const int64_t hwc0 = h * w * c0;
1760   const int64_t wc0 = w * c0;
1761 
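  // Reverse of NCDHW_TO_NDC1HWC0: each host channel c_i reads the device element at
  // (n, d, c_i / c0, h, w, c_i % c0); padded channel slots are simply never visited.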
1762   for (int64_t n_i = 0; n_i < n; n_i++) {
1763     int64_t n_head = n_i * cdhw;
1764     for (int64_t c_i = 0; c_i < c; c_i++) {
1765       int64_t c_head = n_head + c_i * dhw;
1766       for (int64_t d_i = 0; d_i < d; d_i++) {
1767         int64_t d_head = c_head + d_i * hw;
1768         for (int64_t h_i = 0; h_i < h; h_i++) {
1769           int64_t h_head = d_head + h_i * w;
1770           for (int64_t w_i = 0; w_i < w; w_i++) {
1771             int64_t dst_i = h_head + w_i;
1772             int64_t c1_i = c_i / c0;
1773             int64_t c0_i = c_i % c0;
1774             auto src_idx = n_i * dc1hwc0 + d_i * c1hwc0 + c1_i * hwc0 + h_i * wc0 + w_i * c0 + c0_i;
1775             SetData(size, false, src_idx, dst_i, args, result);
1776           }
1777         }
1778       }
1779     }
1780   }
1781   return true;
1782 }
1783 
1784 bool FormatTransfer::FRAC_Z_TO_NCHW_WITH_GROUPS(const FormatArgs &args, void *result, int64_t groups) {
1785   MS_LOG(DEBUG) << "Trans format from frac_z to nchw with groups=" << groups;
1786   return NCHW_TO_FRAC_Z_WITH_GROUPS(args, result, false, groups);
1787 }
1788 
1789 int64_t FormatTransfer::Common4DCheck(const FormatArgs &args) {
1790   if (args.host_shape.size() != kDim4) {
1791     MS_LOG(INTERNAL_EXCEPTION) << "Invalid host shape, host shape dims:" << args.host_shape.size()
1792                                << ", expect dims:" << kNchwDims;
1793   }
1794   auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
1795   if (size < 1) {
1796     MS_LOG(INTERNAL_EXCEPTION) << "Illegal dtype: " << args.src_data_type;
1797   }
1798   return size;
1799 }
1800 
1801 // ########################  RANGE TRANS ########################
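// Shape ranges are transformed the same way as shapes: the (min, max) pair of every host
// dimension is mapped through the corresponding device-format formula, and CalMaxShape
// keeps a dynamic upper bound (kShapeDimAny) dynamic instead of computing a bogus maximum.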
1802 RangePair ShapeRangeTransfer::GetRealRange(const RangePair &ori_range, const std::string &format, const TypeId &type,
1803                                            const std::string &padding_str) const {
1804   const std::set<std::string> no_need_change = {kOpFormat_ND, kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NCDHW};
1805   using RangeTransfer = std::function<RangePair(const RangePair &, const TypeId &)>;
1806   const std::map<std::string, RangeTransfer> format_range_map = {{kOpFormat_NHWC, NHWCRange},
1807                                                                  {kOpFormat_HWCN, HWCNRange},
1808                                                                  {kOpFormat_FRAC_Z, FRAC_ZRange},
1809                                                                  {kOpFormat_NC1HWC0, NC1HWC0Range},
1810                                                                  {kOpFormat_NDC1HWC0, NDC1HWC0Range},
1811                                                                  {kOpFormat_C1HWNCoC0, C1HWNCOC0Range},
1812                                                                  {kOpFormat_NC1HWC0_C04, NC1HWC04Range},
1813                                                                  {kOpFormat_FRACTAL_Z_3D, FRAC_Z_3DRange},
1814                                                                  {kOpFormat_FRACTAL_Z_C04, FRAC_ZC04Range}};
1815   if (no_need_change.find(format) != no_need_change.end()) {
1816     return ori_range;
1817   }
1818   // kOpFormat_FRACTAL_ZN_LSTM and kOpFormat_FRAC_NZ do not need range padding.
1819   if (format == kOpFormat_FRACTAL_ZN_LSTM) {
1820     return FRAC_ZN_LSTMRange(ori_range, type);
1821   }
1822   if (format == kOpFormat_FRAC_NZ) {
1823     return FRAC_NZRange(ori_range, type);
1824   }
1825   auto temp_range = ori_range;
1826   if (ori_range.size() < kDim4 && !IsOneOf3DFormat(format)) {
1827     MS_LOG(DEBUG) << "A special format: " << format << " has a range size less than 4, so pad the range to 4D first.";
1828     temp_range = PaddingRangeTo4D(ori_range, padding_str);
1829   }
1830   if (ori_range.size() < kDim5 && IsOneOf3DFormat(format)) {
1831     MS_LOG(DEBUG) << "A special format: " << format << " has a range size less than 5, so pad the range to 5D first.";
1832     temp_range = PaddingRangeTo5D(ori_range, padding_str);
1833   }
1834   auto iter = format_range_map.find(format);
1835   if (iter == format_range_map.end()) {
1836     MS_LOG(INFO) << "Cannot find a supported format: " << format << ", using the default range.";
1837     return ori_range;
1838   }
1839   return iter->second(temp_range, type);
1840 }
1841 
1842 RangePair ShapeRangeTransfer::NHWCRange(const RangePair &ori_range, const TypeId &) {
1843   RangePair dst_range;
1844   dst_range.push_back(ori_range[kN]);
1845   dst_range.push_back(ori_range[kH]);
1846   dst_range.push_back(ori_range[kW]);
1847   dst_range.push_back(ori_range[kC]);
1848   return dst_range;
1849 }
1850 
1851 RangePair ShapeRangeTransfer::HWCNRange(const RangePair &ori_range, const TypeId &) {
1852   RangePair dst_range;
1853   dst_range.push_back(ori_range[kH]);
1854   dst_range.push_back(ori_range[kW]);
1855   dst_range.push_back(ori_range[kC]);
1856   dst_range.push_back(ori_range[kN]);
1857   return dst_range;
1858 }
1859 
1860 RangePair ShapeRangeTransfer::NC1HWC04Range(const RangePair &ori_range, const TypeId &) {
1861   RangePair dst_range;
1862   const std::pair<int64_t, int64_t> c0 = {k4, k4};
1863   auto tmp_max = CalMaxShape(ori_range[kC].second, (ori_range[kC].second + k4 - 1) / k4);
1864   const std::pair<int64_t, int64_t> c1 = {(ori_range[kC].first + k4 - 1) / k4, tmp_max};
1865   dst_range.push_back(ori_range[kN]);
1866   dst_range.push_back(c1);
1867   dst_range.push_back(ori_range[kH]);
1868   dst_range.push_back(ori_range[kW]);
1869   dst_range.push_back(c0);
1870   return dst_range;
1871 }
1872 
1873 RangePair ShapeRangeTransfer::FRAC_ZC04Range(const RangePair &ori_range, const TypeId &) {
1874   RangePair dst_range;
1875   const std::pair<int64_t, int64_t> c0 = {k4, k4};
1876   const std::pair<int64_t, int64_t> c16 = {kNiSize, kNiSize};
1877 
1878   auto tmp_max = CalMaxShape(c0.second * ori_range[kH].second * ori_range[kW].second,
1879                              (c0.second * ori_range[kH].second * ori_range[kW].second + kNiSize - 1) / kNiSize);
1880   const std::pair<int64_t, int64_t> first_dim = {
1881     (c0.first * ori_range[kH].first * ori_range[kW].first + kNiSize - 1) / kNiSize, tmp_max};
1882 
1883   tmp_max = CalMaxShape(ori_range[kN].second, (ori_range[kN].second + kNiSize - 1) / kNiSize);
1884   const std::pair<int64_t, int64_t> no = {(ori_range[kN].first + kNiSize - 1) / kNiSize, tmp_max};
1885   dst_range.push_back(first_dim);
1886   dst_range.push_back(no);
1887   dst_range.push_back(c16);
1888   dst_range.push_back(c16);
1889   return dst_range;
1890 }
1891 
1892 RangePair ShapeRangeTransfer::FRAC_ZRange(const RangePair &ori_range, const TypeId &type) {
1893   RangePair dst_range;
1894   auto cube = GetCubeSizeByType(type);
1895   const std::pair<int64_t, int64_t> c0 = {cube, cube};
1896 
1897   auto tmp_max = CalMaxShape(ori_range[kN].second, ((ori_range[kN].second + kNiSize - 1) / kNiSize) * kNiSize);
1898 
1899   const std::pair<int64_t, int64_t> cout16 = {((ori_range[kN].first + kNiSize - 1) / kNiSize) * kNiSize, tmp_max};
1900 
1901   tmp_max = CalMaxShape(ori_range[kC].second, ((ori_range[kC].second + cube - 1) / cube) * cube);
1902   const std::pair<int64_t, int64_t> cin16 = {((ori_range[kC].first + cube - 1) / cube) * cube, tmp_max};
1903 
1904   tmp_max = CalMaxShape(ori_range[kH].second * ori_range[kW].second * cin16.second,
1905                         ori_range[kH].second * ori_range[kW].second * cin16.second / cube);
1906   const std::pair<int64_t, int64_t> r0 = {ori_range[kH].first * ori_range[kW].first * cin16.first / cube, tmp_max};
1907 
1908   tmp_max = CalMaxShape(cin16.second, cout16.second / kNiSize);
1909   const std::pair<int64_t, int64_t> r1 = {cout16.first / kNiSize, tmp_max};
1910   const std::pair<int64_t, int64_t> co = {kNiSize, kNiSize};
1911   dst_range.push_back(r0);
1912   dst_range.push_back(r1);
1913   dst_range.push_back(co);
1914   dst_range.push_back(c0);
1915   return dst_range;
1916 }
1917 
1918 RangePair ShapeRangeTransfer::FRAC_NZRange(const RangePair &ori_range, const TypeId &type) {
1919   RangePair dst_range;
1920   auto cube = GetCubeSizeByType(type);
1921   auto ori_size = ori_range.size();
1922   if (ori_size < kDims2) {
1923     return ori_range;
1924   } else {
1925     (void)std::copy(ori_range.begin(), ori_range.end() - kDims2, std::back_inserter(dst_range));
1926   }
1927   const std::pair<int64_t, int64_t> c0 = {cube, cube};
1928   auto tmp_max = CalMaxShape(ori_range[ori_size - 1].second, (ori_range[ori_size - 1].second - 1) / cube + 1);
1929   const std::pair<int64_t, int64_t> w1 = {(ori_range[ori_size - 1].first - 1) / cube + 1, tmp_max};
1930   tmp_max = CalMaxShape(ori_range[ori_size - kDims2].second, (ori_range[ori_size - kDims2].second - 1) / kNiSize + 1);
1931   const std::pair<int64_t, int64_t> h1 = {(ori_range[ori_size - kDims2].first - 1) / kNiSize + 1, tmp_max};
1932   const std::pair<int64_t, int64_t> co = {kNiSize, kNiSize};
1933   dst_range.push_back(w1);
1934   dst_range.push_back(h1);
1935   dst_range.push_back(co);
1936   dst_range.push_back(c0);
1937   return dst_range;
1938 }
1939 
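// For example, an NCHW range of {N:[1,8], C:[1,20], H:[1,32], W:[1,32]} with a 16-element
// cube becomes {N:[1,8], C1:[1,2], H:[1,32], W:[1,32], C0:[16,16]} below, since
// C1 = ceil(C / 16) is applied to both ends of the range.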
1940 RangePair ShapeRangeTransfer::NC1HWC0Range(const RangePair &ori_range, const TypeId &type) {
1941   RangePair dst_range;
1942   auto cube = GetCubeSizeByType(type);
1943   const std::pair<int64_t, int64_t> c0 = {cube, cube};
1944   auto tmp_max = CalMaxShape(ori_range[kC].second, (ori_range[kC].second + cube - 1) / cube);
1945   const std::pair<int64_t, int64_t> c1 = {(ori_range[kC].first + cube - 1) / cube, tmp_max};
1946   dst_range.push_back(ori_range[kN]);
1947   dst_range.push_back(c1);
1948   dst_range.push_back(ori_range[kH]);
1949   dst_range.push_back(ori_range[kW]);
1950   dst_range.push_back(c0);
1951   return dst_range;
1952 }
1953 
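// The LSTM fractal layout splits the host N dimension into 4 gates of h rows (h = N / 4)
// and the host C dimension into i + h columns (i = C - h); the first two device dimensions
// are ceil(i/16) + ceil(h/16) and 4 * ceil(h/16), followed by two 16 x 16 cube dimensions,
// and the range below mirrors that computation for both the lower and upper bounds.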
1954 RangePair ShapeRangeTransfer::FRAC_ZN_LSTMRange(const RangePair &ori_range, const TypeId &) {
1955   RangePair dst_range;
1956   const std::pair<int64_t, int64_t> c0 = {k4, k4};
1957   const std::pair<int64_t, int64_t> c16 = {kCube16, kCube16};
1958 
1959   auto tmp_max = CalMaxShape(ori_range[kN].second, ori_range[kN].second / c0.second);
1960   const std::pair<int64_t, int64_t> h = {ori_range[kN].first / c0.first, tmp_max};
1961 
1962   tmp_max = CalMaxShape(ori_range[kC].second * h.second, ori_range[kC].second - h.second);
1963   const std::pair<int64_t, int64_t> i = {ori_range[kC].first - h.first, tmp_max};
1964 
1965   tmp_max = CalMaxShape(i.second * h.second, (i.second + kCube16 - 1) / kCube16 + (h.second + kCube16 - 1) / kCube16);
1966   const std::pair<int64_t, int64_t> first_dim = {(i.first + kCube16 - 1) / kCube16 + (h.first + kCube16 - 1) / kCube16,
1967                                                  tmp_max};
1968 
1969   tmp_max = CalMaxShape(h.second, c0.second * ((h.second + kCube16 - 1) / kCube16));
1970   const std::pair<int64_t, int64_t> second = {c0.first * ((h.first + kCube16 - 1) / kCube16), tmp_max};
1971   dst_range.push_back(first_dim);
1972   dst_range.push_back(second);
1973   dst_range.push_back(c16);
1974   dst_range.push_back(c16);
1975   return dst_range;
1976 }
1977 
1978 RangePair ShapeRangeTransfer::NDC1HWC0Range(const RangePair &ori_range, const TypeId &type) {
1979   RangePair dst_range;
1980   auto cube = GetCubeSizeByType(type);
1981   const std::pair<int64_t, int64_t> c0 = {cube, cube};
1982   auto tmp_max = CalMaxShape(ori_range[C_ncdhw].second, (ori_range[C_ncdhw].second + cube - 1) / cube);
1983   const std::pair<int64_t, int64_t> c1 = {(ori_range[C_ncdhw].first + cube - 1) / cube, tmp_max};
1984   dst_range.push_back(ori_range[N_ncdhw]);
1985   dst_range.push_back(ori_range[D_ncdhw]);
1986   dst_range.push_back(c1);
1987   dst_range.push_back(ori_range[H_ncdhw]);
1988   dst_range.push_back(ori_range[W_ncdhw]);
1989   dst_range.push_back(c0);
1990   return dst_range;
1991 }
1992 
1993 RangePair ShapeRangeTransfer::C1HWNCOC0Range(const RangePair &ori_range, const TypeId &type) {
1994   RangePair dst_range;
1995   auto cube = GetCubeSizeByType(type);
1996   const std::pair<int64_t, int64_t> c0 = {cube, cube};
1997   auto tmp_max = CalMaxShape(ori_range[kC].second, (ori_range[kC].second - 1) / cube + 1);
1998   const std::pair<int64_t, int64_t> r1 = {(ori_range[kC].first - 1) / cube + 1, tmp_max};
1999   dst_range.push_back(r1);
2000   dst_range.push_back(ori_range[kH]);
2001   dst_range.push_back(ori_range[kW]);
2002   dst_range.push_back(ori_range[kN]);
2003   dst_range.push_back(c0);
2004   dst_range.push_back(c0);
2005   return dst_range;
2006 }
2007 
2008 RangePair ShapeRangeTransfer::FRAC_Z_3DRange(const RangePair &ori_range, const TypeId &type) {
2009   RangePair dst_range;
2010   auto cube = GetCubeSizeByType(type);
2011   const std::pair<int64_t, int64_t> c0 = {cube, cube};
2012   auto tmp_max = CalMaxShape(ori_range[C_ncdhw].second, (ori_range[C_ncdhw].second + cube - 1) / cube);
2013   const std::pair<int64_t, int64_t> c1 = {(ori_range[C_ncdhw].first + cube - 1) / cube, tmp_max};
2014 
2015   tmp_max = CalMaxShape(ori_range[N_ncdhw].second, (ori_range[N_ncdhw].second + kNiSize - 1) / kNiSize);
2016   const std::pair<int64_t, int64_t> n1 = {(ori_range[N_ncdhw].first + kNiSize - 1) / kNiSize, tmp_max};
2017 
2018   const int64_t r1_0 = ori_range[D_ncdhw].first * c1.first * ori_range[H_ncdhw].first * ori_range[W_ncdhw].first;
2019   const int64_t r1_1 =
2020     CalMaxShape(ori_range[D_ncdhw].second * c1.second * ori_range[H_ncdhw].second * ori_range[W_ncdhw].second,
2021                 ori_range[D_ncdhw].second * c1.second * ori_range[H_ncdhw].second * ori_range[W_ncdhw].second);
2022   const std::pair<int64_t, int64_t> r1 = {r1_0, r1_1};
2023   dst_range.push_back(r1);
2024   dst_range.push_back(n1);
2025   dst_range.push_back(c1);
2026   dst_range.push_back(c0);
2027   return dst_range;
2028 }
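// FormatHelper keeps a small table mapping every known device format to a
// FormatInfo(baseFormat, isPadded) entry: baseFormat is the host-side layout the format is
// derived from, and isPadded marks formats whose device shape contains cube padding
// (NC1HWC0, FRAC_Z, NDC1HWC0, FRACTAL_Z_3D, ...). Formats missing from the table are
// treated as padded by IsPadded() to stay on the safe side.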
2029 void FormatHelper::InitInfo() {
2030   info = {{kOpFormat_DEFAULT, FormatInfo(kOpFormat_DEFAULT, false)},
2031           {kOpFormat_NC1HWC0, FormatInfo(kOpFormat_NCHW, true)},
2032           {kOpFormat_ND, FormatInfo(kOpFormat_ND, false)},
2033           {kOpFormat_NCHW, FormatInfo(kOpFormat_NCHW, false)},
2034           {kOpFormat_NHWC, FormatInfo(kOpFormat_NHWC, false)},
2035           {kOpFormat_FRAC_NZ, FormatInfo(kOpFormat_ND, true)},
2036           {kOpFormat_FRAC_Z, FormatInfo(kOpFormat_NCHW, true)},
2037           {kOpFormat_NDHWC, FormatInfo(kOpFormat_NCDHW, false)},
2038           {kOpFormat_NCDHW, FormatInfo(kOpFormat_NCDHW, false)},
2039           {kOpFormat_NDC1HWC0, FormatInfo(kOpFormat_NCDHW, true)},
2040           {kOpFormat_FRACTAL_Z_3D, FormatInfo(kOpFormat_NCDHW, true)}};
2041 }
2042 
2043 FormatHelper &FormatHelper::GetInstance() noexcept {
2044   static FormatHelper instance{};
2045   return instance;
2046 }
2047 
2048 const std::string FormatHelper::GetBaseFormat(const std::string &format) {
2049   const auto &iter = info.find(format);
2050   if (iter != info.end()) {
2051     return iter->second.baseFormat;
2052   } else {
2053     return "";
2054   }
2055 }
2056 
2057 bool FormatHelper::IsBaseFormatType(const std::string &format) {
2058   const auto &iter = info.find(format);
2059   if (iter == info.end()) {
2060     return false;
2061   }
2062 
2063   return iter->first == iter->second.baseFormat;
2064 }
2065 
2066 bool FormatHelper::IsPadded(const std::string &format) {
2067   auto itr = info.find(format);
2068   if (itr != info.end()) {
2069     return itr->second.isPadded;
2070   }
2071   MS_LOG(INFO) << "Unknown format type: " << format;
2072   return true;
2073 }
2074 
2075 void FormatHelper::Clear() { info.clear(); }
2076 }  // namespace trans
2077 }  // namespace mindspore
2078