1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "common/trans.h"
17 #include <functional>
18 #include <numeric>
19 #include <utility>
20 #include <algorithm>
21 #include "utils/ms_utils.h"
22 #include "abstract/utils.h"
23 #include "backend/kernel_compiler/kernel.h"
24 #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
25 #include "runtime/device/convert_tensor_utils.h"
26 #include "utils/convert_utils.h"
27 #include "utils/log_adapter.h"
28 #include "utils/utils.h"
29 
30 using mindspore::abstract::Shape;
31 namespace mindspore {
32 namespace trans {
33 const int b1 = 1;
34 const int b2 = 2;
35 const int b4 = 4;
36 const int b8 = 8;
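// Writes one element of `size` bytes into result[dst_idx]: a zero when pad_zero is set,
// otherwise a copy of args.data[src_idx]. Only 1/2/4/8-byte element sizes are supported.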
37 inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) {
38   switch (size) {
39     case b1:
40       static_cast<uint8_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint8_t *>(args.data)[src_idx];
41       break;
42     case b2:
43       static_cast<uint16_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint16_t *>(args.data)[src_idx];
44       break;
45     case b4:
46       static_cast<uint32_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint32_t *>(args.data)[src_idx];
47       break;
48     case b8:
49       static_cast<uint64_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint64_t *>(args.data)[src_idx];
50       break;
51     default:
52       MS_LOG(EXCEPTION) << "Trans data does not support size " << size;
53   }
54 }
55 
56 // greatest common divisor
57 size_t Gcd(size_t a, size_t b) {
58   if (b == 0) {
59     return 0;
60   }
61   size_t c = b;
62   while (a % b != 0) {
63     c = a % b;
64     a = b;
65     b = c;
66   }
67   return c;
68 }
69 
70 // least common multiple
71 size_t Lcm(size_t a, size_t b) {
72   if (b == 0) {
73     return 0;
74   }
75   size_t ret = (a * b) / (Gcd(a, b));
76   return ret;
77 }
78 
79 template <typename T>
80 T DivCeil(T n1, T n2) {
81   if (n2 != 0) {
82     return (n1 + n2 - 1) / n2;
83   }
84   return 0;
85 }
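// Illustrative values for the helpers above: Gcd(24, 36) == 12, Lcm(24, 36) == 72,
// DivCeil(17, 16) == 2 (DivCeil returns 0 when the divisor is 0).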
86 
87 size_t GetShapeSize(const std::vector<size_t> &shape) {
88   return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
89 }
90 
91 enum class DataTypeTransMode {
92   FROM_FLOAT_TO_FLOAT16,
93   FROM_FLOAT_TO_INT32,
94   FROM_FLOAT16_TO_FLOAT,
95   FROM_FLOAT16_TO_INT32,
96   FROM_FLOAT16_TO_UINT8,
97   FROM_INT32_TO_FLOAT,
98   FROM_INT32_TO_FLOAT16,
99   FROM_INT32_TO_UINT8,
100   FROM_INT32_TO_INT8,
101   FROM_INT32_TO_INT64,
102   FROM_INT32_TO_BOOL,
103   FROM_UINT8_TO_FLOAT,
104   FROM_UINT8_TO_INT32,
105   FROM_UINT8_TO_FLOAT16,
106   FROM_INT8_TO_FLOAT,
107   FROM_INT8_TO_FLOAT16,
108   FROM_INT8_TO_INT32,
109   FROM_INT64_TO_INT32,
110   FROM_UINT16_TO_INT32,
111   FROM_BOOL_TO_FLOAT,
112   FROM_BOOL_TO_INT32,
113   FROM_BOOL_TO_UINT8,
114   FROM_BOOL_TO_FLOAT16,
115   FROM_FLOAT64_TO_FLOAT32,
116   FROM_FLOAT32_TO_FLOAT64
117 };
118 
119 const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{
120   {std::pair<TypeId, TypeId>(kNumberTypeFloat64, kNumberTypeFloat32), DataTypeTransMode::FROM_FLOAT64_TO_FLOAT32},
121   {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat64), DataTypeTransMode::FROM_FLOAT32_TO_FLOAT64},
122   {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat16), DataTypeTransMode::FROM_FLOAT_TO_FLOAT16},
123   {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeInt32), DataTypeTransMode::FROM_FLOAT_TO_INT32},
124   {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeFloat32), DataTypeTransMode::FROM_FLOAT16_TO_FLOAT},
125   {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeInt32), DataTypeTransMode::FROM_FLOAT16_TO_INT32},
126   {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeUInt8), DataTypeTransMode::FROM_FLOAT16_TO_UINT8},
127   {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat32), DataTypeTransMode::FROM_INT32_TO_FLOAT},
128   {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat16), DataTypeTransMode::FROM_INT32_TO_FLOAT16},
129   {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeUInt8), DataTypeTransMode::FROM_INT32_TO_UINT8},
130   {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt8), DataTypeTransMode::FROM_INT32_TO_INT8},
131   {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt64), DataTypeTransMode::FROM_INT32_TO_INT64},
132   {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeBool), DataTypeTransMode::FROM_INT32_TO_BOOL},
133   {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat32), DataTypeTransMode::FROM_UINT8_TO_FLOAT},
134   {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeInt32), DataTypeTransMode::FROM_UINT8_TO_INT32},
135   {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat16), DataTypeTransMode::FROM_UINT8_TO_FLOAT16},
136   {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat32), DataTypeTransMode::FROM_INT8_TO_FLOAT},
137   {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat16), DataTypeTransMode::FROM_INT8_TO_FLOAT16},
138   {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeInt32), DataTypeTransMode::FROM_INT8_TO_INT32},
139   {std::pair<TypeId, TypeId>(kNumberTypeInt64, kNumberTypeInt32), DataTypeTransMode::FROM_INT64_TO_INT32},
140   {std::pair<TypeId, TypeId>(kNumberTypeUInt16, kNumberTypeInt32), DataTypeTransMode::FROM_UINT16_TO_INT32},
141   {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeInt32), DataTypeTransMode::FROM_BOOL_TO_INT32},
142   {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat), DataTypeTransMode::FROM_BOOL_TO_FLOAT},
143   {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeUInt8), DataTypeTransMode::FROM_BOOL_TO_UINT8},
144   {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat16), DataTypeTransMode::FROM_BOOL_TO_FLOAT16}};
145 
146 void CheckMemSize(const TypeIdArgs &args) {
147   auto src_type_size = abstract::TypeIdSize(args.host_data_type);
148   auto dst_type_size = abstract::TypeIdSize(args.device_data_type);
149   if (src_type_size < 1 || dst_type_size < 1) {
150     MS_LOG(EXCEPTION) << "Invalid src or dst data type.";
151   }
152   if (args.data_size / src_type_size != args.host_shape_size) {
153     MS_LOG(EXCEPTION) << "Invalid src or dst data size.";
154   }
155 }
156 
157 template <typename SrcT, typename DstT>
158 void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) {
159   CheckMemSize(args);
160   for (size_t idx = 0; idx != data_size; idx++) {
161     SrcT src_data = static_cast<const SrcT *>(args.data)[idx];
162     static_cast<DstT *>(dst)[idx] = static_cast<DstT>(src_data);
163   }
164 }
165 
166 template <typename SrcT>
167 void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const size_t data_size) {
168   CheckMemSize(args);
169   auto src_data = static_cast<const SrcT *>(args.data);
170   auto half_data = static_cast<float16 *>(dst);
171   for (size_t i = 0; i < data_size; i++) {
172     half_data[i] = float16(src_data[i]);
173   }
174 }
175 
176 bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const DataTypeTransMode mode) {
177   using DtypeKernel = std::function<void(const TypeIdArgs &, void *, const size_t)>;
178   const std::map<DataTypeTransMode, DtypeKernel> cast_kernel_map{
179     {DataTypeTransMode::FROM_FLOAT_TO_INT32, TransDataSrc2Dst<float, int32_t>},
180     {DataTypeTransMode::FROM_FLOAT64_TO_FLOAT32, TransDataSrc2Dst<double, float>},
181     {DataTypeTransMode::FROM_FLOAT32_TO_FLOAT64, TransDataSrc2Dst<float, double>},
182     {DataTypeTransMode::FROM_FLOAT16_TO_INT32, TransDataSrc2Dst<float16, int32_t>},
183     {DataTypeTransMode::FROM_FLOAT16_TO_UINT8, TransDataSrc2Dst<float16, uint8_t>},
184     {DataTypeTransMode::FROM_INT32_TO_FLOAT, TransDataSrc2Dst<int32_t, float>},
185     {DataTypeTransMode::FROM_INT32_TO_INT8, TransDataSrc2Dst<int32_t, int8_t>},
186     {DataTypeTransMode::FROM_INT32_TO_INT64, TransDataSrc2Dst<int32_t, int64_t>},
187     {DataTypeTransMode::FROM_INT32_TO_UINT8, TransDataSrc2Dst<int32_t, uint8_t>},
188     {DataTypeTransMode::FROM_INT32_TO_BOOL, TransDataSrc2Dst<int32_t, int8_t>},
189     {DataTypeTransMode::FROM_INT32_TO_FLOAT16, TransDataSrc2Fp16<int32_t>},
190     {DataTypeTransMode::FROM_UINT8_TO_FLOAT, TransDataSrc2Dst<uint8_t, float>},
191     {DataTypeTransMode::FROM_UINT8_TO_INT32, TransDataSrc2Dst<uint8_t, int32_t>},
192     {DataTypeTransMode::FROM_UINT8_TO_FLOAT16, TransDataSrc2Fp16<uint8_t>},
193     {DataTypeTransMode::FROM_INT8_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
194     {DataTypeTransMode::FROM_INT8_TO_FLOAT16, TransDataSrc2Fp16<int8_t>},
195     {DataTypeTransMode::FROM_INT8_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
196     {DataTypeTransMode::FROM_INT64_TO_INT32, TransDataSrc2Dst<int64_t, int32_t>},
197     {DataTypeTransMode::FROM_UINT16_TO_INT32, TransDataSrc2Dst<uint16_t, int32_t>},
198     {DataTypeTransMode::FROM_BOOL_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
199     {DataTypeTransMode::FROM_BOOL_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
200     {DataTypeTransMode::FROM_BOOL_TO_UINT8, TransDataSrc2Dst<int8_t, uint8_t>},
201     {DataTypeTransMode::FROM_BOOL_TO_FLOAT16, TransDataSrc2Fp16<int8_t>}};
202 
203   if (mode == DataTypeTransMode::FROM_FLOAT_TO_FLOAT16) {
204     device::FloatToHalf(dst, args.data, data_size);
205     return true;
206   } else if (mode == DataTypeTransMode::FROM_FLOAT16_TO_FLOAT) {
207     device::HalfToFloat(dst, args.data, data_size);
208     return true;
209   }
210   auto iter = cast_kernel_map.find(mode);
211   if (iter != cast_kernel_map.end()) {
212     iter->second(args, dst, data_size);
213     return true;
214   } else {
215     MS_LOG(ERROR) << "Unsupported datatype trans";
216     return false;
217   }
218 }
219 
220 namespace {
221 bool HasShapeDynamic(const std::vector<int64_t> &shape_list) {
222   return std::any_of(shape_list.begin(), shape_list.end(), [](int64_t shape) { return shape == Shape::SHP_ANY; });
223 }
224 
225 template <typename T>
226 bool CheckDims(const std::vector<T> &shape) {
227   if (shape.size() != kNchwDims) {
228     MS_LOG(ERROR) << "Host shape dims should be 4";
229     return false;
230   }
231   return true;
232 }
233 
234 std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape) {
235   if (!CheckDims(shape)) {
236     MS_LOG(EXCEPTION) << "Check dims failed.";
237   }
238   return shape;
239 }
240 
241 std::vector<int64_t> NchwDeviceDynamicShape(const std::vector<int64_t> &shape) {
242   if (!CheckDims(shape)) {
243     MS_LOG(EXCEPTION) << "Check dims failed.";
244   }
245   return shape;
246 }
247 
248 std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
249   if (!CheckDims(shape)) {
250     MS_LOG(EXCEPTION) << "Check dims failed.";
251   }
252   std::vector<size_t> device_shape;
253   device_shape.push_back(shape[kN]);
254   device_shape.push_back(shape[kH]);
255   device_shape.push_back(shape[kW]);
256   device_shape.push_back(shape[kC]);
257   return device_shape;
258 }
259 
260 std::vector<int64_t> NhwcDeviceDynamicShape(const std::vector<int64_t> &shape) {
261   if (!CheckDims(shape)) {
262     MS_LOG(EXCEPTION) << "Check dims failed.";
263   }
264   std::vector<int64_t> device_shape;
265   device_shape.push_back(shape[kN]);
266   device_shape.push_back(shape[kH]);
267   device_shape.push_back(shape[kW]);
268   device_shape.push_back(shape[kC]);
269   return device_shape;
270 }
271 
272 std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
273   if (!CheckDims(shape)) {
274     MS_LOG(EXCEPTION) << "Check dims failed.";
275   }
276   std::vector<size_t> device_shape;
277   device_shape.push_back(shape[kH]);
278   device_shape.push_back(shape[kW]);
279   device_shape.push_back(shape[kC]);
280   device_shape.push_back(shape[kN]);
281   return device_shape;
282 }
283 
284 std::vector<int64_t> HwchDeviceDynamicShape(const std::vector<int64_t> &shape) {
285   if (!CheckDims(shape)) {
286     MS_LOG(EXCEPTION) << "Check dims failed.";
287   }
288   std::vector<int64_t> device_shape;
289   device_shape.push_back(shape[kH]);
290   device_shape.push_back(shape[kW]);
291   device_shape.push_back(shape[kC]);
292   device_shape.push_back(shape[kN]);
293   return device_shape;
294 }
295 
296 std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) {
297   if (!CheckDims(shape)) {
298     MS_LOG(EXCEPTION) << "Check dims failed.";
299   }
300   std::vector<size_t> device_shape;
301   const size_t cout16 = ((shape[kN] + kCubeSize - 1) / kCubeSize) * kCubeSize;
302   const size_t cin16 = ((shape[kC] + kCubeSize - 1) / kCubeSize) * kCubeSize;
303   device_shape.push_back(shape[kH] * shape[kW] * cin16 / kCubeSize);
304   device_shape.push_back(cout16 / kCubeSize);
305   device_shape.push_back(kCubeSize);
306   device_shape.push_back(kCubeSize);
307   return device_shape;
308 }
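// Example (assuming kCubeSize == 16): an NCHW host shape {32, 17, 7, 7} maps to the
// FRAC_Z device shape {7 * 7 * 2, 2, 16, 16} = {98, 2, 16, 16}.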
309 
310 std::vector<int64_t> FracZDeviceDynamicShape(const std::vector<int64_t> &shape) {
311   if (!CheckDims(shape)) {
312     MS_LOG(EXCEPTION) << "Check dims failed.";
313   }
314   auto tmp = SizeToLong(kCubeSize);
315   std::vector<int64_t> device_shape;
316   if (HasShapeDynamic({shape[kC], shape[kH], shape[kW]})) {
317     device_shape.push_back(Shape::SHP_ANY);
318   } else {
319     const int64_t cin16 = ((shape[kC] + tmp - 1) / tmp) * tmp;
320     device_shape.push_back(shape[kH] * shape[kW] * cin16 / tmp);
321   }
322   if (shape[kN] == Shape::SHP_ANY) {
323     device_shape.push_back(Shape::SHP_ANY);
324   } else {
325     const int64_t cout16 = ((shape[kN] + tmp - 1) / tmp) * tmp;
326     device_shape.push_back(cout16 / tmp);
327   }
328   device_shape.push_back(tmp);
329   device_shape.push_back(tmp);
330   return device_shape;
331 }
332 
333 std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
334   if (!CheckDims(shape)) {
335     MS_LOG(EXCEPTION) << "Check dims failed.";
336   }
337   std::vector<size_t> device_shape;
338   const size_t C1 = (shape[kC] + kCubeSize - 1) / kCubeSize;
339   const size_t C0 = kCubeSize;
340   device_shape.push_back(shape[kN]);
341   device_shape.push_back(C1);
342   device_shape.push_back(shape[kH]);
343   device_shape.push_back(shape[kW]);
344   device_shape.push_back(C0);
345   return device_shape;
346 }
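// Example (assuming kCubeSize == 16): an NCHW host shape {2, 17, 7, 7} maps to the
// NC1HWC0 device shape {2, 2, 7, 7, 16}.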
347 
348 std::vector<int64_t> Nc1hwc0DeviceDynamicShape(const std::vector<int64_t> &shape) {
349   if (!CheckDims(shape)) {
350     MS_LOG(EXCEPTION) << "Check dims failed.";
351   }
352   std::vector<int64_t> device_shape;
353   auto tmp = SizeToLong(kCubeSize);
354   const int64_t C1 = (shape[kC] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[kC] + tmp - 1) / tmp;
355   const int64_t C0 = tmp;
356   device_shape.push_back(shape[kN]);
357   device_shape.push_back(C1);
358   device_shape.push_back(shape[kH]);
359   device_shape.push_back(shape[kW]);
360   device_shape.push_back(C0);
361   return device_shape;
362 }
363 
364 std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) {
365   // NCDHW
366   if (shape.size() != kNcdhw) {
367     MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
368   }
369   std::vector<size_t> device_shape;
370   const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
371   const size_t C0 = kCubeSize;
372   device_shape.push_back(shape[N_ncdhw]);
373   device_shape.push_back(shape[D_ncdhw]);
374   device_shape.push_back(C1);
375   device_shape.push_back(shape[H_ncdhw]);
376   device_shape.push_back(shape[W_ncdhw]);
377   device_shape.push_back(C0);
378   return device_shape;
379 }
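// Example (assuming kCubeSize == 16): an NCDHW host shape {2, 17, 5, 7, 7} maps to the
// NDC1HWC0 device shape {2, 5, 2, 7, 7, 16}.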
380 
381 std::vector<int64_t> Ndc1hwc0DeviceDynamicShape(const std::vector<int64_t> &shape) {
382   // NCDHW
383   if (shape.size() != kNcdhw) {
384     MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
385   }
386   auto tmp = SizeToLong(kCubeSize);
387   std::vector<int64_t> device_shape;
388   const int64_t C1 = (shape[1] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[1] + tmp - 1) / tmp;
389   const int64_t C0 = tmp;
390   device_shape.push_back(shape[N_ncdhw]);
391   device_shape.push_back(shape[D_ncdhw]);
392   device_shape.push_back(C1);
393   device_shape.push_back(shape[H_ncdhw]);
394   device_shape.push_back(shape[W_ncdhw]);
395   device_shape.push_back(C0);
396   return device_shape;
397 }
398 
399 std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) {
400   // NCDHW -> Frac_Z_3D
401   if (shape.size() != kNcdhw) {
402     MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
403   }
404   std::vector<size_t> device_shape;
405   const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
406   const size_t N1 = (shape[0] + kCubeSize - 1) / kCubeSize;
407   device_shape.push_back(shape[D_ncdhw] * C1 * shape[H_ncdhw] * shape[W_ncdhw]);
408   device_shape.push_back(N1);
409   device_shape.push_back(kCubeSize);
410   device_shape.push_back(kCubeSize);
411   return device_shape;
412 }
413 
414 std::vector<int64_t> Fracz3DDeviceDynamicShape(const std::vector<int64_t> &shape) {
415   // NCDHW -> Frac_Z_3D
416   if (shape.size() != kNcdhw) {
417     MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
418   }
419   std::vector<int64_t> device_shape;
420   auto tmp = SizeToLong(kCubeSize);
421   if (HasShapeDynamic({shape[C_ncdhw], shape[D_ncdhw], shape[H_ncdhw], shape[W_ncdhw]})) {
422     device_shape.push_back(Shape::SHP_ANY);
423   } else {
424     const int64_t C1 = (shape[1] + tmp - 1) / tmp;
425     device_shape.push_back(shape[D_ncdhw] * C1 * shape[H_ncdhw] * shape[W_ncdhw]);
426   }
427 
428   const int64_t N1 = (shape[0] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[0] + tmp - 1) / tmp;
429   device_shape.push_back(N1);
430   device_shape.push_back(tmp);
431   device_shape.push_back(tmp);
432   return device_shape;
433 }
434 
435 std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
436   if (!CheckDims(shape)) {
437     MS_LOG(EXCEPTION) << "Check dims failed.";
438   }
439   std::vector<size_t> device_shape;
440   device_shape.push_back((shape[kC] - 1) / kCubeSize + 1);
441   device_shape.push_back(shape[kH]);
442   device_shape.push_back(shape[kW]);
443   device_shape.push_back(shape[kN]);
444   device_shape.push_back(kCubeSize);
445   device_shape.push_back(kCubeSize);
446   return device_shape;
447 }
448 
449 std::vector<int64_t> C1hwncoc0DeviceDynamicShape(const std::vector<int64_t> &shape) {
450   if (!CheckDims(shape)) {
451     MS_LOG(EXCEPTION) << "Check dims failed.";
452   }
453   std::vector<int64_t> device_shape;
454   auto tmp = SizeToLong(kCubeSize);
455   shape[kC] == Shape::SHP_ANY ? device_shape.push_back(Shape::SHP_ANY)
456                               : device_shape.push_back((shape[kC] - 1) / tmp + 1);
457   device_shape.push_back(shape[kH]);
458   device_shape.push_back(shape[kW]);
459   device_shape.push_back(shape[kN]);
460   device_shape.push_back(tmp);
461   device_shape.push_back(tmp);
462   return device_shape;
463 }
464 
465 std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
466   if (!CheckDims(shape)) {
467     MS_LOG(EXCEPTION) << "Check dims failed.";
468   }
469   std::vector<size_t> device_shape;
470   const size_t c0 = 4;
471   auto first_dim = DivCeil(c0 * shape[kH] * shape[kW], kCubeSize);
472   auto no = DivCeil(shape.at(kN), kCubeSize);
473   device_shape.push_back(first_dim);
474   device_shape.push_back(no);
475   device_shape.push_back(kCubeSize);
476   device_shape.push_back(kCubeSize);
477   return device_shape;
478 }
479 
480 std::vector<int64_t> FracZc04DeviceDynamicShape(const std::vector<int64_t> &shape) {
481   if (!CheckDims(shape)) {
482     MS_LOG(EXCEPTION) << "Check dims failed.";
483   }
484   std::vector<int64_t> device_shape;
485   const int64_t c0 = 4;
486 
487   int64_t first_dim;
488   if (HasShapeDynamic({shape[kH], shape[kW]})) {
489     first_dim = Shape::SHP_ANY;
490   } else {
491     first_dim = DivCeil(c0 * shape[kH] * shape[kW], SizeToLong(kCubeSize));
492   }
493   auto shape_kN = shape.at(kN);
494   int64_t no = (shape_kN == Shape::SHP_ANY) ? Shape::SHP_ANY : DivCeil(shape.at(kN), SizeToLong(kCubeSize));
495   device_shape.push_back(first_dim);
496   device_shape.push_back(no);
497   device_shape.push_back(SizeToLong(kCubeSize));
498   device_shape.push_back(SizeToLong(kCubeSize));
499   return device_shape;
500 }
501 
502 std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
503   if (!CheckDims(shape)) {
504     MS_LOG(EXCEPTION) << "Check dims failed.";
505   }
506   std::vector<size_t> device_shape;
507   const size_t C1 = 1;
508   const size_t C0 = 4;
509   device_shape.push_back(shape[kN]);
510   device_shape.push_back(C1);
511   device_shape.push_back(shape[kH]);
512   device_shape.push_back(shape[kW]);
513   device_shape.push_back(C0);
514   return device_shape;
515 }
516 
517 std::vector<int64_t> Nc1hwc04DeviceDynamicShape(const std::vector<int64_t> &shape) {
518   if (!CheckDims(shape)) {
519     MS_LOG(EXCEPTION) << "Check dims failed.";
520   }
521   std::vector<int64_t> device_shape;
522   const int64_t C1 = 1;
523   const int64_t C0 = 4;
524   device_shape.push_back(shape[kN]);
525   device_shape.push_back(C1);
526   device_shape.push_back(shape[kH]);
527   device_shape.push_back(shape[kW]);
528   device_shape.push_back(C0);
529   return device_shape;
530 }
531 
532 std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) {
533   if (shape.size() < kNcdhw) {
534     MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is NCDHW.";
535   }
536   return shape;
537 }
538 
539 std::vector<int64_t> NcdhwDeviceDynamicShape(const std::vector<int64_t> &shape) {
540   if (shape.size() < kNcdhw) {
541     MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is NCDHW.";
542   }
543   return shape;
544 }
545 
546 // change channel-first shape to channel-last shape.
547 // eg. [2,3,4] => [2,4,3]; [2,3,4,5] => [2,4,5,3]
548 std::vector<size_t> ChannelLastDeviceShape(const std::vector<size_t> &shape) {
549   auto dim = shape.size();
550   std::vector<size_t> axis;
551   axis.resize(dim);
552   int step_value = 2;
553   std::iota(axis.begin() + 1, axis.end(), step_value);
554   axis[dim - 1] = 1;
555 
556   std::vector<size_t> device_shape;
557   (void)std::transform(axis.begin(), axis.end(), std::back_inserter(device_shape),
558                        [&shape](size_t n) { return shape[n]; });
559 
560   return device_shape;
561 }
562 
563 // change channel-first shape to channel-last shape.
564 // eg. [2,3,4] => [2,4,3]; [2,3,4,5] => [2,4,5,3]
565 std::vector<int64_t> ChannelLastDeviceDynamicShape(const std::vector<int64_t> &shape) {
566   auto dim = shape.size();
567   std::vector<int64_t> axis;
568   axis.resize(dim);
569   int step_value = 2;
570   std::iota(axis.begin() + 1, axis.end(), step_value);
571   axis[dim - 1] = 1;
572 
573   std::vector<int64_t> device_shape;
574   (void)std::transform(axis.begin(), axis.end(), std::back_inserter(device_shape),
575                        [&shape](size_t n) { return shape[n]; });
576 
577   return device_shape;
578 }
579 
580 std::vector<size_t> FracZDeviceShapeWithGroups(const std::vector<size_t> &shape, const int64_t groups = 1) {
581   if (!CheckDims(shape)) {
582     MS_LOG(EXCEPTION) << "Check dims failed.";
583   }
584   if (groups <= 0) {
585     MS_LOG(EXCEPTION) << "The value of groups should be greater than 0, but got " << groups;
586   }
587   size_t group_size = LongToSize(groups);
588   size_t cin_ori = shape[kC];
589   size_t cout_ori = shape[kN] / group_size;
590   size_t e_mult = std::min(Lcm(Lcm(cin_ori, kCubeSize) / cin_ori, Lcm(cout_ori, kCubeSize) / cout_ori), group_size);
591   size_t cin_opt = DivCeil(e_mult * cin_ori, kCubeSize) * kCubeSize;
592   size_t c1_dim = cin_opt / kCubeSize;
593   size_t g_dim = DivCeil(group_size, e_mult);
594   size_t n1 = DivCeil(cout_ori * e_mult, kCubeSize);
595   std::vector<size_t> device_shape;
596   device_shape.push_back(g_dim * c1_dim * shape[kH] * shape[kW]);
597   device_shape.push_back(n1);
598   device_shape.push_back(kNiSize);
599   device_shape.push_back(kCubeSize);
600   return device_shape;
601 }
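// Example (assuming kCubeSize == kNiSize == 16): for an NCHW host shape {64, 16, 3, 3} and
// groups = 4, e_mult = 1, so the device shape is {4 * 1 * 3 * 3, 1, 16, 16} = {36, 1, 16, 16}.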
602 
603 std::vector<int64_t> FracZDeviceShapeWithGroups(const std::vector<int64_t> &shape, const int64_t groups = 1) {
604   if (!CheckDims(shape)) {
605     MS_LOG(EXCEPTION) << "Check dims failed.";
606   }
607 
608   int64_t c1_dim = Shape::SHP_ANY;
609   int64_t g_dim = Shape::SHP_ANY;
610   int64_t n1 = Shape::SHP_ANY;
611   if (groups <= 0) {
612     MS_LOG(EXCEPTION) << "The value of groups should be greater than 0, but got " << groups;
613   }
614   auto tmp = SizeToLong(kCubeSize);
615   if (!HasShapeDynamic({shape[kC], shape[kN]})) {
616     size_t group_size = LongToSize(groups);
617     size_t cin_ori_tmp = LongToSize(shape[kC]);
618     size_t cout_ori_tmp = LongToSize(shape[kN]) / group_size;
619     size_t e_mult =
620       std::min(Lcm(Lcm(cin_ori_tmp, kCubeSize) / cin_ori_tmp, Lcm(cout_ori_tmp, kCubeSize) / cout_ori_tmp), group_size);
621     int64_t cin_opt = SizeToLong(DivCeil(e_mult * cin_ori_tmp, kCubeSize) * kCubeSize);
622     c1_dim = cin_opt / tmp;
623     g_dim = SizeToLong(DivCeil(group_size, e_mult));
624     n1 = SizeToLong(DivCeil(cout_ori_tmp * e_mult, kCubeSize));
625   }
626 
627   std::vector<int64_t> device_shape;
628   if (!HasShapeDynamic({shape[kC], shape[kN], shape[kH], shape[kW]})) {
629     device_shape.push_back(g_dim * c1_dim * shape[kH] * shape[kW]);
630   } else {
631     device_shape.push_back(Shape::SHP_ANY);
632   }
633   device_shape.push_back(n1);
634   device_shape.push_back(SizeToLong(kNiSize));
635   device_shape.push_back(tmp);
636   return device_shape;
637 }
638 
639 std::vector<size_t> FracNZDeviceShape(const std::vector<size_t> &shape) {
640   if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) {
641     // For [1] and [1024] shape we can treat it as an NZ shape
642     return shape;
643   }
644   std::vector<size_t> device_shape;
645   if (shape.size() < kShape2dDims) {
646     MS_LOG(EXCEPTION) << "Format FRACTAL_NZ doesn't support shape with " << shape.size() << " dims";
647   } else {
648     const auto remove_dim = 2;
649     (void)std::copy(shape.begin(), shape.end() - remove_dim, std::back_inserter(device_shape));
650   }
651   auto h1 = (shape[shape.size() - kDim2] - 1) / kCubeSize + 1;
652   auto w1 = (shape[shape.size() - kDim1] - 1) / kCubeSize + 1;
653   device_shape.push_back(w1);
654   device_shape.push_back(h1);
655   device_shape.push_back(kCubeSize);
656   device_shape.push_back(kCubeSize);
657   return device_shape;
658 }
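// Example (assuming kCubeSize == 16): a host shape {64, 32, 48} maps to the FRACTAL_NZ
// device shape {64, 3, 2, 16, 16} (w1 = ceil(48 / 16), h1 = ceil(32 / 16)).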
659 
660 std::vector<int64_t> FracNZDeviceDynamicShape(const std::vector<int64_t> &shape) {
661   std::vector<int64_t> device_shape;
662   if (shape.size() == 1 && (shape[0] == 1 || shape[0] % SizeToLong(kCubeSize) == 0)) {
663     // For [1] and [1024] shape we can treat it as an NZ shape
664     return shape;
665   }
666   if (shape.size() < kShape2dDims) {
667     MS_LOG(EXCEPTION) << "Format FRACTAL_NZ doesn't support shape with " << shape.size() << " dims";
668   } else {
669     (void)std::copy(shape.begin(), shape.end() - kDim2, std::back_inserter(device_shape));
670   }
671   int64_t h_shape = shape[shape.size() - kDim2];
672   int64_t w_shape = shape[shape.size() - kDim1];
673   int64_t h1 = (h_shape == Shape::SHP_ANY) ? Shape::SHP_ANY : (h_shape - 1) / SizeToLong(kCubeSize) + 1;
674   int64_t w1 = (w_shape == Shape::SHP_ANY) ? Shape::SHP_ANY : (w_shape - 1) / SizeToLong(kCubeSize) + 1;
675   device_shape.push_back(w1);
676   device_shape.push_back(h1);
677   device_shape.push_back(kCubeSize);
678   device_shape.push_back(kCubeSize);
679   return device_shape;
680 }
681 
682 std::vector<size_t> FracNZLSTMDeviceShape(const std::vector<size_t> &shape) {
683   const size_t c0 = 4;
684   const size_t h = shape.at(kN) / c0;
685   const size_t i = shape.at(kC) - h;
686   const size_t first = DivCeil(i, kCubeSize) + DivCeil(h, kCubeSize);
687   const size_t second = c0 * DivCeil(h, kCubeSize);
688   std::vector<size_t> device_shape;
689   device_shape.push_back(first);
690   device_shape.push_back(second);
691   device_shape.push_back(kCubeSize);
692   device_shape.push_back(kCubeSize);
693   return device_shape;
694 }
695 
696 std::vector<int64_t> FracNZLSTMDeviceDynamicShape(const std::vector<int64_t> &shape) {
697   std::vector<int64_t> device_shape;
698   const int64_t c0 = 4;
699   const int64_t h_shape = shape.at(kN);
700   const int64_t i_shape = shape.at(kC);
701   const int64_t h = (h_shape == Shape::SHP_ANY) ? Shape::SHP_ANY : h_shape / c0;
702 
703   int64_t first = Shape::SHP_ANY;
704   if (h_shape != Shape::SHP_ANY && i_shape != Shape::SHP_ANY) {
705     int64_t i = i_shape - h;
706     first = DivCeil(i, SizeToLong(kCubeSize)) + DivCeil(h, SizeToLong(kCubeSize));
707   }
708   const int64_t second = (h == Shape::SHP_ANY) ? Shape::SHP_ANY : c0 * DivCeil(h, SizeToLong(kCubeSize));
709   device_shape.push_back(first);
710   device_shape.push_back(second);
711   device_shape.push_back(kCubeSize);
712   device_shape.push_back(kCubeSize);
713   return device_shape;
714 }
715 
716 std::vector<size_t> FracZNRNNDeviceShape(const std::vector<size_t> &shape,
717                                          const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16}) {
718   if (shape.size() < kShape2dDims) {
719     MS_LOG(EXCEPTION) << "Format FRACTAL_ZN_RNN doesn't support shape with " << shape.size() << " dims";
720   }
721   size_t input_size = LongToSize(input_hidden_size[0]);
722   size_t hidden_size = LongToSize(input_hidden_size[1]);
723   auto dim_last1 = shape[shape.size() - kDim1];
724   auto dim_last2 = shape[shape.size() - kDim2];
725   if (hidden_size == 0) {
726     MS_LOG(EXCEPTION) << "Hidden_size should not be 0.";
727   }
728   // cppcheck-suppress *
729   if (dim_last1 % hidden_size != 0) {
730     MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hidden_size;
731   }
732   size_t n_num = dim_last1 / hidden_size;
733   const size_t NUM16 = 16;
734   const size_t C0 = kCubeSize;
735 
736   std::vector<size_t> device_shape = shape;
737   if (dim_last2 == input_size || dim_last2 == hidden_size) {
738     device_shape[shape.size() - kDim2] = DivCeil(dim_last2, NUM16);
739   } else if (dim_last2 == input_size + hidden_size) {
740     device_shape[shape.size() - kDim2] = DivCeil(input_size, NUM16) + DivCeil(hidden_size, NUM16);
741   } else {
742     MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid.";
743   }
744   device_shape[shape.size() - 1] = n_num * DivCeil(hidden_size, C0);
745   device_shape.push_back(NUM16);
746   device_shape.push_back(C0);
747   return device_shape;
748 }
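// Example (assuming kCubeSize == 16): for a host shape {80, 96} with input_size = 32 and
// hidden_size = 48, the FRACTAL_ZN_RNN device shape is {2 + 3, 2 * 3, 16, 16} = {5, 6, 16, 16}.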
749 
750 std::vector<int64_t> FracZNRNNDeviceDynamicShape(const std::vector<int64_t> &shape,
751                                                  const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16}) {
752   if (shape.size() < kShape2dDims) {
753     MS_LOG(EXCEPTION) << "Format FRACTAL_ZN_RNN doesn't support shape with " << shape.size() << " dims";
754   }
755   int64_t input_size = input_hidden_size[0];
756   int64_t hidden_size = input_hidden_size[1];
757   auto dim_last1 = shape[shape.size() - 1];
758   auto dim_last2 = shape[shape.size() - 2];
759   const int64_t NUM16 = 16;
760   const int64_t C0 = SizeToLong(kCubeSize);
761 
762   std::vector<int64_t> device_shape = shape;
763   if (dim_last2 == Shape::SHP_ANY) {
764     device_shape[shape.size() - kDim2] = Shape::SHP_ANY;
765   } else if (dim_last2 == input_size || dim_last2 == hidden_size) {
766     device_shape[shape.size() - kDim2] = DivCeil(dim_last2, NUM16);
767   } else if (dim_last2 == input_size + hidden_size) {
768     device_shape[shape.size() - kDim2] = DivCeil(input_size, NUM16) + DivCeil(hidden_size, NUM16);
769   } else {
770     MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid.";
771   }
772   if (dim_last1 == Shape::SHP_ANY) {
773     device_shape[shape.size() - kDim1] = Shape::SHP_ANY;
774   } else {
775     if (dim_last1 % hidden_size != 0) {
776       MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hidden_size;
777     }
778     int64_t n_num = shape[shape.size() - 1] / hidden_size;
779     device_shape[shape.size() - kDim1] = n_num * DivCeil(hidden_size, C0);
780   }
781   device_shape.push_back(NUM16);
782   device_shape.push_back(C0);
783   return device_shape;
784 }
785 
786 std::vector<size_t> NDRNNBiasDeviceShape(const std::vector<size_t> &shape, const int64_t hidden_size = 16) {
787   if (shape.empty()) {
788     MS_LOG(EXCEPTION) << "Format ND_RNN_BIAS doesn't support an empty shape.";
789   }
790   if (hidden_size <= 0) {
791     MS_LOG(EXCEPTION) << "Hidden_size should be greater than 0, but got " << hidden_size;
792   }
793   size_t hid_size = LongToSize(hidden_size);
794   // cppcheck-suppress *
795   if (shape[shape.size() - 1] % hid_size != 0) {
796     MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hid_size;
797   }
798   size_t n_num = shape[shape.size() - 1] / hid_size;
799   const size_t C0 = kCubeSize;
800   std::vector<size_t> device_shape = shape;
801   device_shape[shape.size() - 1] = n_num * DivCeil(hid_size, C0) * C0;
802   return device_shape;
803 }
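// Example (assuming kCubeSize == 16): a host shape {80} with hidden_size = 40 gives n_num = 2
// and an ND_RNN_BIAS device shape of {2 * 3 * 16} = {96}.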
804 
805 std::vector<int64_t> NDRNNBiasDeviceDynamicShape(const std::vector<int64_t> &shape, const int64_t hidden_size = 16) {
806   if (shape.empty()) {
807     MS_LOG(EXCEPTION) << "Format ND_RNN_BIAS doesn't support an empty shape.";
808   }
809   const int64_t C0 = SizeToLong(kCubeSize);
810   std::vector<int64_t> device_shape = shape;
811   // cppcheck-suppress *
812   auto dim_last1 = shape[shape.size() - 1];
813   if (dim_last1 == Shape::SHP_ANY) {
814     device_shape[shape.size() - 1] = Shape::SHP_ANY;
815   } else {
816     if (hidden_size <= 0) {
817       MS_LOG(EXCEPTION) << "Hidden_size should be greater than 0, but got " << hidden_size;
818     }
819     // cppcheck-suppress *
820     if (dim_last1 % hidden_size != 0) {
821       MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hidden_size;
822     }
823     int64_t n_num = shape[shape.size() - 1] / hidden_size;
824     device_shape[shape.size() - 1] = n_num * DivCeil(hidden_size, C0) * C0;
825   }
826   return device_shape;
827 }
828 }  // namespace
829 
830 int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index) {
831   if (node == nullptr) {
832     return 1;
833   }
834   if (node->isa<CNode>()) {
835     auto cnode = node->cast<CNodePtr>();
836     if (AnfAlgo::HasNodeAttr(kAttrFracZGroup, cnode)) {
837       auto node_name = AnfAlgo::GetCNodeName(cnode);
838       if (node_name == kAllReduceOpName || node_name == kBroadcastOpName) {
839         // if index does not exist in fracz_group_idx, return the default value 1
840         auto fz_group_idx = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, kAttrFracZGroupIdx);
841         int64_t out_index = SizeToLong(index);
842         auto fz_iter = std::find(std::begin(fz_group_idx), std::end(fz_group_idx), out_index);
843         if (fz_iter == std::end(fz_group_idx)) {
844           return 1;
845         }
846       }
847       return AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFracZGroup);
848     }
849   } else if (node->isa<Parameter>()) {
850     auto param = node->cast<ParameterPtr>();
851     MS_EXCEPTION_IF_NULL(param);
852     return param->fracz_group();
853   }
854   return 1;
855 }
856 
857 std::vector<int64_t> GetAttrInputAndHiddenSize(const AnfNodePtr &node) {
858   MS_EXCEPTION_IF_NULL(node);
859   std::vector<int64_t> input_hidden_size = {kAlign16, kAlign16};
860   if (!node->isa<CNode>()) {
861     return input_hidden_size;
862   }
863   auto cnode = node->cast<CNodePtr>();
864   MS_EXCEPTION_IF_NULL(cnode);
865   if (!AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode) || !AnfAlgo::HasNodeAttr(kAttrInputSize, cnode)) {
866     MS_LOG(EXCEPTION)
867       << "Node with format FRACTAL_ZN_RNN or ND_RNN_BIAS should have hidden_size or input_size attr. Node info:"
868       << cnode->DebugString();
869   }
870   input_hidden_size[0] = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrInputSize);
871   input_hidden_size[1] = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrHiddenSize);
872   return input_hidden_size;
873 }
874 
875 bool IsNeedPadding(const std::string &format, const size_t shape_size) {
876   if (shape_size == 0) {
877     return false;
878   }
879   if (format == kOpFormat_DEFAULT || format == kOpFormat_NCHW ||
880       kNoPaddingFormatSet.find(format) != kNoPaddingFormatSet.end()) {
881     return false;
882   } else if (shape_size < kNchwDims) {
883     return true;
884   }
885   return false;
886 }
887 
888 ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
889   MS_EXCEPTION_IF_NULL(node);
890   ShapeVector shape;
891   std::vector<size_t> host_shape;
892   if (node->isa<ValueNode>()) {
893     auto value_node = node->cast<ValueNodePtr>();
894     MS_EXCEPTION_IF_NULL(value_node);
895     auto node_value = value_node->value();
896     MS_EXCEPTION_IF_NULL(node_value);
897     auto tensor = node_value->cast<tensor::TensorPtr>();
898     if (tensor == nullptr) {
899       MS_LOG(EXCEPTION) << "The node[" << node->DebugString() << "]'s value cannot be converted to a tensor.";
900     }
901     auto shape_temp = tensor->shape();
902     (void)std::transform(shape_temp.begin(), shape_temp.end(), std::back_inserter(host_shape), LongToSize);
903     if (host_shape.empty()) {
904       host_shape.push_back(1);
905     }
906   } else {
907     host_shape = AnfAlgo::GetOutputInferShape(node, index);
908   }
909   auto format = AnfAlgo::GetOutputFormat(node, index);
910   if (trans::IsNeedPadding(format, host_shape.size())) {
911     host_shape = trans::PaddingShape(host_shape, format, AnfAlgo::GetOutputReshapeType(node, index));
912   }
913   std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToLong);
914   return shape;
915 }
916 
917 void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
918   MS_EXCEPTION_IF_NULL(reshape_type_vec);
919   if (reshape_type_str.empty()) {
920     MS_LOG(DEBUG) << "Reshape type str is empty, no padding is needed.";
921     return;
922   }
923   for (const auto &c : reshape_type_str) {
924     switch (c) {
925       case 'N':
926         reshape_type_vec->push_back(N);
927         break;
928       case 'C':
929         reshape_type_vec->push_back(C);
930         break;
931       case 'H':
932         reshape_type_vec->push_back(H);
933         break;
934       case 'W':
935         reshape_type_vec->push_back(W);
936         break;
937       default:
938         MS_LOG(EXCEPTION) << "Unknown axis " << c << " in reshape type.";
939     }
940   }
941 }
942 void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec) {
943   MS_EXCEPTION_IF_NULL(reshape_type_vec);
944   if (reshape_type_str.empty()) {
945     MS_LOG(DEBUG) << "Reshape type str is empty, no padding is needed.";
946     return;
947   }
948   for (const auto &c : reshape_type_str) {
949     switch (c) {
950       case 'N':
951         reshape_type_vec->push_back(N_ncdhw);
952         break;
953       case 'C':
954         reshape_type_vec->push_back(C_ncdhw);
955         break;
956       case 'D':
957         reshape_type_vec->push_back(D_ncdhw);
958         break;
959       case 'H':
960         reshape_type_vec->push_back(H_ncdhw);
961         break;
962       case 'W':
963         reshape_type_vec->push_back(W_ncdhw);
964         break;
965       default:
966         MS_LOG(EXCEPTION) << "Unknown axis " << c << " in reshape type.";
967     }
968   }
969 }
970 
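// Maps a host shape to its device-format shape, padding the host shape to 4D (or 5D for the 3D
// formats) by the default rules when needed. For example (assuming kCubeSize == 16 and
// kAlign16 == 16), TransShapeToDevice({32, 17, 7, 7}, kOpFormat_NC1HWC0, 1, {16, 16})
// returns {32, 2, 7, 7, 16}.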
971 std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
972                                        const int64_t groups, const std::vector<int64_t> &input_hidden_size) {
973   using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
974   const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
975                                                                     {kOpFormat_NHWC, NhwcDeviceShape},
976                                                                     {kOpFormat_HWCN, HwchDeviceShape},
977                                                                     {kOpFormat_FRAC_Z, FracZDeviceShape},
978                                                                     {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape},
979                                                                     {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
980                                                                     {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
981                                                                     {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape},
982                                                                     {kOpFormat_NCDHW, NcdhwDeviceShape},
983                                                                     {kOpFormat_ChannelLast, ChannelLastDeviceShape},
984                                                                     {kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape},
985                                                                     {kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape},
986                                                                     {kOpFormat_FRAC_NZ, FracNZDeviceShape},
987                                                                     {kOpFormat_FRACTAL_ZN_LSTM, FracNZLSTMDeviceShape}};
988 
989   if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
990     return shape;
991   }
992   if (groups > 1 && format == kOpFormat_FRAC_Z) {
993     return FracZDeviceShapeWithGroups(shape, groups);
994   }
995   if (format == kOpFormat_FRACTAL_ZN_RNN) {
996     return FracZNRNNDeviceShape(shape, input_hidden_size);
997   }
998   if (format == kOpFormat_ND_RNN_BIAS) {
999     return NDRNNBiasDeviceShape(shape, input_hidden_size[1]);
1000   }
1001   auto temp_shape = shape;
1002   if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM &&
1003       shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
1004     MS_LOG(WARNING) << "Get device shape with a shape whose dims are less than 4, it will be padded to 4D by default first";
1005     temp_shape = PaddingShapeTo4dDefault(shape);
1006   }
1007   if (shape.size() != kNcdhw && k3DFormatSet.find(format) != k3DFormatSet.end()) {
1008     temp_shape = PaddingShapeTo5dDefault(shape);
1009   }
1010   auto iter = device_shape_map.find(format);
1011   if (iter == device_shape_map.end()) {
1012     MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
1013   }
1014   return iter->second(temp_shape);
1015 }
1016 
1017 std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const std::string &format,
1018                                         const int64_t groups, const std::vector<int64_t> &input_hidden_size) {
1019   using DeviceShapeTransfer = std::function<std::vector<int64_t>(const std::vector<int64_t> &)>;
1020   const std::map<std::string, DeviceShapeTransfer> device_shape_map{
1021     {kOpFormat_NCHW, NchwDeviceDynamicShape},
1022     {kOpFormat_NHWC, NhwcDeviceDynamicShape},
1023     {kOpFormat_HWCN, HwchDeviceDynamicShape},
1024     {kOpFormat_FRAC_Z, FracZDeviceDynamicShape},
1025     {kOpFormat_NC1HWC0, Nc1hwc0DeviceDynamicShape},
1026     {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceDynamicShape},
1027     {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceDynamicShape},
1028     {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceDynamicShape},
1029     {kOpFormat_NCDHW, NcdhwDeviceDynamicShape},
1030     {kOpFormat_ChannelLast, ChannelLastDeviceDynamicShape},
1031     {kOpFormat_NDC1HWC0, Ndc1hwc0DeviceDynamicShape},
1032     {kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceDynamicShape},
1033     {kOpFormat_FRAC_NZ, FracNZDeviceDynamicShape},
1034     {kOpFormat_FRACTAL_ZN_LSTM, FracNZLSTMDeviceDynamicShape}};
1035 
1036   if (format == kOpFormat_ND || format == kOpFormat_DEFAULT || format == kOpFormat_NCHW) {
1037     return shape;
1038   }
1039   if (groups > 1 && format == kOpFormat_FRAC_Z) {
1040     return FracZDeviceShapeWithGroups(shape, groups);
1041   }
1042   if (format == kOpFormat_FRACTAL_ZN_RNN) {
1043     return FracZNRNNDeviceDynamicShape(shape, input_hidden_size);
1044   }
1045   if (format == kOpFormat_ND_RNN_BIAS) {
1046     return NDRNNBiasDeviceDynamicShape(shape, input_hidden_size[1]);
1047   }
1048   auto temp_shape = shape;
1049   if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM &&
1050       shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
1051     MS_LOG(WARNING) << "Get device shape with a shape whose dims are less than 4, it will be padded to 4D by default first";
1052     temp_shape = PaddingShapeTo4dDefault(shape);
1053   }
1054   if (shape.size() != kNcdhw && k3DFormatSet.find(format) != k3DFormatSet.end()) {
1055     temp_shape = PaddingShapeTo5dDefault(shape);
1056   }
1057   auto iter = device_shape_map.find(format);
1058   if (iter == device_shape_map.end()) {
1059     MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
1060   }
1061   return iter->second(temp_shape);
1062 }
1063 
1064 bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
1065   if (args.host_shape.size() != kNchwDims) {
1066     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
1067     return false;
1068   }
1069   MS_EXCEPTION_IF_NULL(size);
1070   MS_EXCEPTION_IF_NULL(total_size);
1071   *size = abstract::TypeIdSize(args.src_data_type);
1072   if (*size < 1) {
1073     MS_LOG(ERROR) << "Illegal dtype.";
1074     return false;
1075   }
1076   *total_size = abstract::ShapeSize(args.device_shape) * (*size);
1077   if (*total_size != args.device_size) {
1078     MS_LOG(ERROR) << "Illegal total data size, total_size:" << *total_size << ", device_size:" << args.device_size;
1079     return false;
1080   }
1081   return true;
1082 }
1083 
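// Casts args.host_shape_size elements from args.data (interpreted as args.host_data_type)
// into result (as args.device_data_type); returns false when the type pair is unsupported.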
1084 bool TransDataType(const TypeIdArgs &args, void *result) {
1085   MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to "
1086                 << TypeIdLabel(args.device_data_type);
1087   MS_EXCEPTION_IF_NULL(result);
1088   std::pair<TypeId, TypeId> type_info(args.host_data_type, args.device_data_type);
1089   auto iter = mode_map.find(type_info);
1090   if (iter == mode_map.end()) {
1091     MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.host_data_type)
1092                   << ", dst_type:" << TypeIdLabel(args.device_data_type);
1093     return false;
1094   }
1095   auto trans_mode = iter->second;
1096   if (!CastKernel(args, result, args.host_shape_size, trans_mode)) {
1097     MS_LOG(ERROR) << "Failed to trans datatype.";
1098     return false;
1099   }
1100   return true;
1101 }
1102 
1103 bool TransFormat(const FormatArgs &args, void *result, int64_t groups) {
1104   MS_LOG(DEBUG) << "Start trans format.";
1105   if (abstract::TypeIdSize(args.src_data_type) < 1) {
1106     MS_LOG(ERROR) << "Invalid datatype.";
1107     return false;
1108   }
1109   if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
1110     return NchwTo4D(args, result);
1111   }
1112   if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) {
1113     return NchwToFracZWithGroups(args, result, groups);
1114   }
1115   auto iter = kTransFormatMapOfHostToDevice.find(args.device_format);
1116   if (iter == kTransFormatMapOfHostToDevice.end()) {
1117     MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
1118   }
1119   return iter->second(args, result);
1120 }
1121 
1122 bool TransFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index) {
1123   int64_t groups = 1;
1124   if (args.device_format == kOpFormat_FRAC_Z) {
1125     groups = GetAttrGroups(node, index);
1126   }
1127   return TransFormat(args, result, groups);
1128 }
1129 
1130 bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, int64_t groups) {
1131   const std::map<std::string, FormatTransfer> format_trans_map{
1132     {kOpFormat_FRAC_Z, FracZToNchw},         {kOpFormat_FRAC_NZ, FracNzToNchw},
1133     {kOpFormat_NC1HWC0, Nc1hwc0ToNchw},      {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
1134     {kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}, {kOpFormat_NDC1HWC0, Ndc1hwc0ToNcdhw},
1135     {kOpFormat_FRACTAL_Z_3D, FracZ3DToNcdhw}};
1136 
1137   MS_LOG(DEBUG) << "Start trans format.";
1138   if (abstract::TypeIdSize(args.src_data_type) < 1) {
1139     MS_LOG(ERROR) << "Invalid datatype.";
1140     return false;
1141   }
1142   if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
1143     return ToNchw(args, result);
1144   }
1145   if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) {
1146     return FracZToNchwWithGroups(args, result, groups);
1147   }
1148   auto iter = format_trans_map.find(args.device_format);
1149   if (iter == format_trans_map.end()) {
1150     MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
1151   }
1152   return iter->second(args, result);
1153 }
1154 
1155 bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index) {
1156   int64_t groups = 1;
1157   if (args.device_format == kOpFormat_FRAC_Z) {
1158     groups = GetAttrGroups(node, index);
1159   }
1160   return TransFormatFromDeviceToHost(args, result, groups);
1161 }
1162 
1163 bool NchwTo4D(const FormatArgs &args, void *result) {
1164   // trans nchw to 4d
1165   MS_LOG(DEBUG) << "Trans format from nchw to 4d.";
1166   MS_EXCEPTION_IF_NULL(result);
1167   size_t size = 0;
1168   size_t total_size = 0;
1169   if (!CheckArgs(args, &size, &total_size)) {
1170     MS_LOG(ERROR) << "Check args failed.";
1171     return false;
1172   }
1173   auto n = args.host_shape[kN];
1174   auto c = args.host_shape[kC];
1175   auto h = args.host_shape[kH];
1176   auto w = args.host_shape[kW];
1177   for (size_t ni = 0; ni < n; ni++) {
1178     for (size_t ci = 0; ci < c; ci++) {
1179       for (size_t hi = 0; hi < h; hi++) {
1180         for (size_t wi = 0; wi < w; wi++) {
1181           auto src_idx = ni * c * h * w + ci * h * w + hi * w + wi;
1182           size_t dst_idx = 0;
1183           if (args.device_format == kOpFormat_NHWC) {
1184             dst_idx = ni * h * w * c + hi * w * c + wi * c + ci;
1185           } else if (args.device_format == kOpFormat_HWCN) {
1186             dst_idx = hi * w * c * n + wi * c * n + ci * n + ni;
1187           }
1188           SetData(size, false, src_idx, dst_idx, args, result);
1189         }
1190       }
1191     }
1192   }
1193   return true;
1194 }
1195 
1196 bool ToNchw(const FormatArgs &args, void *result) {
1197   MS_LOG(DEBUG) << "Trans format to nchw from 4d.";
1198   MS_EXCEPTION_IF_NULL(result);
1199   size_t size = 0;
1200   size_t total_size = 0;
1201   if (!CheckArgs(args, &size, &total_size)) {
1202     MS_LOG(ERROR) << "Check args failed.";
1203     return false;
1204   }
1205   auto n = args.host_shape[kN];
1206   auto c = args.host_shape[kC];
1207   auto h = args.host_shape[kH];
1208   auto w = args.host_shape[kW];
1209   for (size_t ni = 0; ni < n; ni++) {
1210     for (size_t ci = 0; ci < c; ci++) {
1211       for (size_t hi = 0; hi < h; hi++) {
1212         for (size_t wi = 0; wi < w; wi++) {
1213           auto dst_idx = ni * c * h * w + ci * h * w + hi * w + wi;
1214           size_t src_idx = 0;
1215           if (args.device_format == kOpFormat_NHWC) {
1216             src_idx = ni * h * w * c + hi * w * c + wi * c + ci;
1217           } else if (args.device_format == kOpFormat_HWCN) {
1218             src_idx = hi * w * c * n + wi * c * n + ci * n + ni;
1219           }
1220           SetData(size, false, src_idx, dst_idx, args, result);
1221         }
1222       }
1223     }
1224   }
1225   return true;
1226 }
1227 
1228 bool NchwToFracZ(const FormatArgs &args, void *result) {
1229   MS_LOG(DEBUG) << "Trans format from nchw to frac_z";
1230   MS_EXCEPTION_IF_NULL(result);
1231   if (args.host_shape.size() != kNchwDims) {
1232     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
1233     return false;
1234   }
1235   auto size = abstract::TypeIdSize(args.src_data_type);
1236   if (size < 1) {
1237     MS_LOG(ERROR) << "Illegal dtype.";
1238     return false;
1239   }
1240   auto n = args.host_shape[kN];
1241   auto c = args.host_shape[kC];
1242   auto h = args.host_shape[kH];
1243   auto w = args.host_shape[kW];
1244   const size_t c0 = 16;
1245   auto c1 = DivCeil(c, c0);
1246   auto hw = h * w;
1247   auto chw = c * hw;
1248   auto hwc0 = hw * c0;
1249   auto nchw = n * chw;
1250 
1251   auto hf_cnt = DivCeil(n, kCubeSize);
1252   auto vf_cnt = c1 * hw;
1253   auto fractal_ele_cnt = c0 * kCubeSize;
1254   auto total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
1255   auto dst_size = total_ele_cnt * size;
1256   if (dst_size != args.device_size) {
1257     MS_LOG(ERROR) << "Illegal total data size, "
1258                   << "dst size is: " << dst_size << ", device size is: " << args.device_size;
1259     return false;
1260   }
1261 
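  // FRACTAL_Z lays the tensor out as a (C1*H*W) x ceil(N/16) grid of fractal
  // blocks, each holding kCubeSize n-columns by c0 channel-rows.  Elements
  // whose n or c index falls outside the real shape are zero padded.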
1262   for (size_t vfi = 0; vfi < vf_cnt; vfi++) {
1263     auto vf_base_i = vfi * hf_cnt;  // vertical fractal matrix base index
1264     for (size_t hfi = 0; hfi < hf_cnt; hfi++) {
1265       auto gfi = vf_base_i + hfi;  // global fractal matrix index
1266       auto src_n_offset = hfi * chw * kCubeSize;
1267       auto src_f_offset = src_n_offset + vfi % hw + vfi / hw * hwc0;
1268       for (size_t row = 0; row < c0; row++) {
1269         auto src_ci = vfi / hw * c0 + row;
1270         auto src_row_offset = src_f_offset + row * hw;
1271         for (size_t col = 0; col < kCubeSize; col++) {
1272           auto src_ni = hfi * kCubeSize + col;
1273           auto src_idx = src_row_offset + chw * col;
1274           auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row;
1275           auto pad_zero = src_ni >= n || src_idx >= nchw || src_ci >= c;
1276           SetData(size, pad_zero, src_idx, dst_idx, args, result);
1277         }
1278       }
1279     }
1280   }
1281   return true;
1282 }
1283 
1284 bool FracZToNchw(const FormatArgs &args, void *result) {
1285   MS_LOG(DEBUG) << "Trans format from frac_z to nchw";
1286   MS_EXCEPTION_IF_NULL(result);
1287   if (args.host_shape.size() != kNchwDims) {
1288     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
1289     return false;
1290   }
1291   auto size = abstract::TypeIdSize(args.src_data_type);
1292   if (size < 1) {
1293     MS_LOG(ERROR) << "Illegal dtype.";
1294     return false;
1295   }
1296   auto total_size = abstract::ShapeSize(args.device_shape) * size;
1297   if (total_size != args.device_size) {
1298     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1299     return false;
1300   }
1301 
1302   auto n0 = args.device_shape.at(1);
1303   auto ni = args.device_shape.at(2);
1304   auto c0 = args.device_shape.at(3);
1305   auto n = args.host_shape[kN];
1306   auto c = args.host_shape[kC];
1307   auto h = args.host_shape[kH];
1308   auto w = args.host_shape[kW];
1309   auto nc = ni * n0;
1310   auto ncc0 = nc * c0;
1311   auto wncc0 = w * ncc0;
1312   auto hwncc0 = h * wncc0;
1313   auto hw = h * w;
1314   auto chw = c * hw;
1315 
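  // Inverse transform: the device shape here is (C1*H*W, N1, N0, C0).  For
  // each NCHW destination element, c is split into (c1, c0) and the value is
  // gathered from
  //   src = c1 * H*W*N1*N0*C0 + h * W*N1*N0*C0 + w * N1*N0*C0 + n * C0 + c0.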
1316   for (size_t n_idx = 0; n_idx < n; n_idx++) {
1317     size_t n_head_addr = n_idx * chw;
1318     for (size_t c_idx = 0; c_idx < c; c_idx++) {
1319       size_t c_head_addr = n_head_addr + c_idx * hw;
1320       for (size_t h_idx = 0; h_idx < h; h_idx++) {
1321         size_t h_head_addr = c_head_addr + h_idx * w;
1322         for (size_t w_idx = 0; w_idx < w; w_idx++) {
1323           size_t dst_idx = h_head_addr + w_idx;
1324           size_t c1_idx = c_idx / c0;
1325           size_t c0_idx = c_idx % c0;
1326           size_t nc_idx = n_idx;
1327           size_t src_idx = c1_idx * hwncc0 + h_idx * wncc0 + w_idx * ncc0 + nc_idx * c0 + c0_idx;
1328           SetData(size, false, src_idx, dst_idx, args, result);
1329         }
1330       }
1331     }
1332   }
1333   return true;
1334 }
1335 
1336 bool NchwToFracZc04(const FormatArgs &args, void *result) {
1337   // trans nchw to FracZc04
1338   MS_LOG(DEBUG) << "Trans format from nchw to FracZc04.";
1339   MS_EXCEPTION_IF_NULL(result);
1340   size_t size = 0;
1341   size_t total_size = 0;
1342   if (!CheckArgs(args, &size, &total_size)) {
1343     MS_LOG(ERROR) << "Check args failed.";
1344     return false;
1345   }
1346   auto cube = kCubeSize;
1347   auto n = args.host_shape[kN];
1348   auto c = args.host_shape[kC];
1349   auto h = args.host_shape[kH];
1350   auto w = args.host_shape[kW];
1351   const size_t c0 = 4;
1352   auto c1 = DivCeil(c, c0);
1353   auto hwc0 = h * w * c0;
1354   auto hwc = h * w * c;
1355   auto nhwc = n * h * w * c;
1356   auto n_cnt = DivCeil(n, cube);
1357   auto v_cnt = DivCeil(h * w * c0 * c1, cube);
1358   size_t dst_idx = 0;
1359 
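  // FRACTAL_Z_C04 uses c0 = 4: the (C1*H*W*C0) x N plane is tiled into
  // kCubeSize x kCubeSize blocks written sequentially (dst_idx just counts
  // up); positions beyond the real n or c are zero padded.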
1360   for (size_t vi = 0; vi < v_cnt; vi++) {
1361     for (size_t ni = 0; ni < n_cnt; ni++) {
1362       for (size_t col = 0; col < cube; col++) {
1363         for (size_t row = 0; row < cube; row++) {
1364           size_t cur_cube_n = cube * ni + col;
1365           size_t cur_cube_c1hwc0 = cube * vi + row;
1366           auto desc_g = cur_cube_n / n;
1367           auto desc_n = cur_cube_n % n;
1368           auto desc_c1 = cur_cube_c1hwc0 / hwc0;
1369           auto desc_c0 = cur_cube_c1hwc0 % c0;
1370           auto desc_h = (cur_cube_c1hwc0 - hwc0 * desc_c1) / (w * c0);
1371           auto desc_w = (cur_cube_c1hwc0 - hwc0 * desc_c1 - w * c0 * desc_h) / c0;
1372           auto c_idx = desc_c1 * c0 + desc_c0;
1373           auto src_idx = desc_g * nhwc + desc_n * hwc + c_idx * h * w + desc_h * w + desc_w;
1374           auto pad_zero = desc_g >= 1 || desc_n >= n || c_idx >= c;
1375           SetData(size, pad_zero, src_idx, dst_idx, args, result);
1376           dst_idx++;
1377         }
1378       }
1379     }
1380   }
1381   return true;
1382 }
1383 
1384 bool NchwToNc1hwc04(const FormatArgs &args, void *result) {
1385   MS_LOG(DEBUG) << "Trans format from nchw to Nc1hwc04.";
1386   return NchwToNc1hwc0(args, result);
1387 }
1388 
1389 bool Nc1hwc04ToNchw(const FormatArgs &args, void *result) {
1390   MS_LOG(DEBUG) << "Trans format from Nc1hwc04 to nchw.";
1391   return Nc1hwc0ToNchw(args, result);
1392 }
1393 
1394 bool TransShapeToNz(const std::vector<size_t> &host_shape, std::vector<size_t> *hw_shape) {
1395   MS_EXCEPTION_IF_NULL(hw_shape);
1396   if (host_shape.empty()) {
1397     MS_LOG(ERROR) << "Size of vector is 0.";
1398     return false;
1399   }
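  // Collapse the host shape to (times, H, W): a 1-D tensor becomes (1, 1, N);
  // otherwise all leading dimensions are folded into "times" and the last two
  // dimensions are kept as H and W.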
1400   switch (host_shape.size()) {
1401     case 1:
1402       hw_shape->push_back(1);
1403       hw_shape->push_back(1);
1404       hw_shape->push_back(host_shape[0]);
1405       return true;
1406     default:
1407       auto size = host_shape.size();
1408       if (size < kDim2) {
1409         MS_LOG(ERROR) << "Illegal size.";
1410         return false;
1411       }
1412       size_t times = 1;
1413       for (size_t i = 0; i != size - kDim2; i++) {
1414         times *= host_shape[i];
1415       }
1416       hw_shape->push_back(times);
1417       hw_shape->push_back(host_shape[size - kDim2]);
1418       hw_shape->push_back(host_shape[size - kDim1]);
1419       return true;
1420   }
1421 }
1422 
1423 bool NchwToFracNz(const FormatArgs &args, void *result) {
1424   MS_LOG(DEBUG) << "Trans format from nchw to frac_nz.";
1425   MS_EXCEPTION_IF_NULL(result);
1426   std::vector<size_t> hw_shape;
1427   if (!TransShapeToNz(args.host_shape, &hw_shape)) {
1428     MS_LOG(ERROR) << "Trans shape failed.";
1429     return false;
1430   }
1431   if (hw_shape.size() < kDim3 || args.device_shape.size() < kDim4) {
1432     MS_LOG(ERROR) << "Invalid shape size.";
1433     return false;
1434   }
1435   auto size = abstract::TypeIdSize(args.src_data_type);
1436   if (size < 1) {
1437     MS_LOG(ERROR) << "Illegal dtype";
1438     return false;
1439   }
1440 
1441   auto dst_size = abstract::ShapeSize(args.device_shape) * size;
1442   if (dst_size != args.device_size) {
1443     MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
1444     return false;
1445   }
1446   auto times = hw_shape.at(0);
1447   auto h = hw_shape.at(1);
1448   auto w = hw_shape.at(2);
1449   auto hw = h * w;
1450 
1451   auto shape_size = args.device_shape.size();
1452   auto w1 = args.device_shape[shape_size - 4];
1453   auto h1 = args.device_shape[shape_size - 3];
1454   auto h0 = args.device_shape[shape_size - 2];
1455   auto w0 = args.device_shape[shape_size - 1];
1456   auto h1h0w0 = h1 * h0 * w0;
1457   auto w1h1h0w0 = w1 * h1h0w0;
1458   auto num_w1 = w / w0;
1459 
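  // FRACTAL_NZ: the device shape ends with (W1, H1, H0, W0).  Each row of
  // length w is split into blocks of w0 elements that land in different w1
  // planes; the second inner loop copies the trailing partial block when w is
  // not a multiple of w0.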
1460   for (size_t times_idx = 0; times_idx < times; times_idx++) {
1461     auto times_head = times_idx * w1h1h0w0;
1462     auto src_times_head = times_idx * hw;
1463     for (size_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
1464       auto h1h0_head = times_head + h1h0_idx * w0;
1465       auto src_h_head = src_times_head + h1h0_idx * w;
1466       for (size_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
1467         for (size_t i = 0; i < w0; ++i) {
1468           size_t src_idx = src_h_head + w1_idx * w0 + i;
1469           size_t dst_idx = h1h0_head + w1_idx * h1h0w0 + i;
1470           SetData(size, false, src_idx, dst_idx, args, result);
1471         }
1472       }
1473       auto w1_head = num_w1 * w0;
1474       for (size_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
1475         auto src_w_idx = w1_head + w0_idx;
1476         size_t dst_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
1477         size_t src_idx = src_h_head + src_w_idx;
1478         SetData(size, false, src_idx, dst_idx, args, result);
1479       }
1480     }
1481   }
1482   return true;
1483 }
1484 
1485 bool FracNzToNchw(const FormatArgs &args, void *result) {
1486   MS_LOG(DEBUG) << "Trans format from frac_nz to nchw";
1487   MS_EXCEPTION_IF_NULL(result);
1488   std::vector<size_t> hw_shape;
1489   if (!TransShapeToNz(args.host_shape, &hw_shape)) {
1490     MS_LOG(ERROR) << "Trans shape failed.";
1491     return false;
1492   }
1493   if (hw_shape.size() < kDim3 || args.device_shape.size() < kDim4) {
1494     MS_LOG(ERROR) << "Invalid shape size.";
1495     return false;
1496   }
1497   auto size = abstract::TypeIdSize(args.src_data_type);
1498   if (size < 1) {
1499     MS_LOG(ERROR) << "Illegal dtype";
1500     return false;
1501   }
1502 
1503   auto dst_size = abstract::ShapeSize(args.device_shape) * size;
1504   if (dst_size != args.device_size) {
1505     MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
1506     return false;
1507   }
1508   auto times = hw_shape.at(0);
1509   auto h = hw_shape.at(1);
1510   auto w = hw_shape.at(2);
1511   auto hw = h * w;
1512 
1513   auto shape_size = args.device_shape.size();
1514   auto w1 = args.device_shape[shape_size - 4];
1515   auto h1 = args.device_shape[shape_size - 3];
1516   auto h0 = args.device_shape[shape_size - 2];
1517   auto w0 = args.device_shape[shape_size - 1];
1518   auto h1h0w0 = h1 * h0 * w0;
1519   auto w1h1h0w0 = w1 * h1h0w0;
1520   auto num_w1 = w / w0;
1521 
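  // Mirror of NchwToFracNz: the same index arithmetic with src and dst swapped.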
1522   for (size_t times_idx = 0; times_idx < times; times_idx++) {
1523     auto times_head = times_idx * w1h1h0w0;
1524     auto src_times_head = times_idx * hw;
1525     for (size_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
1526       auto h1h0_head = times_head + h1h0_idx * w0;
1527       auto src_h_head = src_times_head + h1h0_idx * w;
1528       for (size_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
1529         for (size_t i = 0; i < w0; ++i) {
1530           size_t src_idx = h1h0_head + w1_idx * h1h0w0 + i;
1531           size_t dst_idx = src_h_head + w1_idx * w0 + i;
1532           SetData(size, false, src_idx, dst_idx, args, result);
1533         }
1534       }
1535       auto w1_head = num_w1 * w0;
1536       for (size_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
1537         auto src_w_idx = w1_head + w0_idx;
1538         size_t src_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
1539         size_t dst_idx = src_h_head + src_w_idx;
1540         SetData(size, false, src_idx, dst_idx, args, result);
1541       }
1542     }
1543   }
1544   return true;
1545 }
1546 
1547 bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
1548   MS_LOG(DEBUG) << "Trans format from nchw to Nc1hwc0";
1549   MS_EXCEPTION_IF_NULL(result);
1550   if (args.host_shape.size() != kNchwDims) {
1551     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
1552     return false;
1553   }
1554   auto size = abstract::TypeIdSize(args.src_data_type);
1555   if (size < 1) {
1556     MS_LOG(ERROR) << "Illegal dtype.";
1557     return false;
1558   }
1559   auto total_size = abstract::ShapeSize(args.device_shape) * size;
1560   if (total_size != args.device_size) {
1561     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1562     return false;
1563   }
1564 
1565   auto n = args.host_shape[kN];
1566   auto c = args.host_shape[kC];
1567   auto h = args.host_shape[kH];
1568   auto w = args.host_shape[kW];
1569   size_t c0 = kCubeSize;
1570   if (args.device_format == kOpFormat_NC1HWC0_C04) {
1571     c0 = kCubeSize_C04;
1572   }
1573   auto c1 = DivCeil(c, c0);
1574   auto hw = h * w;
1575   auto chw = c * hw;
1576   auto c1hwc0 = c1 * hw * c0;
1577   auto wc0 = w * c0;
1578 
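  // NC1HWC0 splits C into c1 = ceil(C / c0) groups of c0 channels (c0 is
  // kCubeSize, or 4 for NC1HWC0_C04) stored as (N, C1, H, W, C0); channels
  // past the real C are zero padded.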
1579   for (size_t n_idx = 0; n_idx < n; n_idx++) {
1580     size_t n_head_addr = n_idx * c1hwc0;
1581     for (size_t c1_idx = 0; c1_idx < c1; c1_idx++) {
1582       size_t c1_head_addr = n_head_addr + c1_idx * hw * c0;
1583       for (size_t h_idx = 0; h_idx < h; h_idx++) {
1584         size_t h_head_addr = c1_head_addr + h_idx * wc0;
1585         for (size_t w_idx = 0; w_idx < w; w_idx++) {
1586           size_t w_head_addr = h_head_addr + w_idx * c0;
1587           for (size_t c0_idx = 0; c0_idx < c0; c0_idx++) {
1588             size_t dst_idx = c0_idx + w_head_addr;
1589             size_t c_idx = c0_idx + c1_idx * c0;
1590             size_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx;
1591             auto pad_zero = c_idx >= c;
1592             SetData(size, pad_zero, src_idx, dst_idx, args, result);
1593           }
1594         }
1595       }
1596     }
1597   }
1598   return true;
1599 }
1600 
1601 bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
1602   MS_LOG(DEBUG) << "Trans format from nc1hwc0 to nchw";
1603   MS_EXCEPTION_IF_NULL(result);
1604   if (args.host_shape.size() != kNchwDims) {
1605     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
1606     return false;
1607   }
1608   auto size = abstract::TypeIdSize(args.src_data_type);
1609   if (size < 1) {
1610     MS_LOG(ERROR) << "Illegal dtype.";
1611     return false;
1612   }
1613   auto total_size = abstract::ShapeSize(args.device_shape) * size;
1614   if (total_size != args.device_size) {
1615     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1616     return false;
1617   }
1618 
1619   auto n = args.host_shape[kN];
1620   auto c = args.host_shape[kC];
1621   auto h = args.host_shape[kH];
1622   auto w = args.host_shape[kW];
1623   auto c1 = args.device_shape[1];
1624   auto c0 = args.device_shape[4];
1625 
1626   auto hw = h * w;
1627   auto chw = c * hw;
1628   auto wc0 = w * c0;
1629   auto hwc0 = h * wc0;
1630   auto c1hwc0 = c1 * hwc0;
1631 
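  // Inverse of NchwToNc1hwc0: split c into (c1, c0) and gather from
  //   src = n * C1*H*W*C0 + c1 * H*W*C0 + h * W*C0 + w * C0 + c0.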
1632   for (size_t n_idx = 0; n_idx < n; n_idx++) {
1633     size_t n_head_addr = n_idx * chw;
1634     for (size_t c_idx = 0; c_idx < c; c_idx++) {
1635       size_t c_head_addr = n_head_addr + c_idx * hw;
1636       for (size_t h_idx = 0; h_idx < h; h_idx++) {
1637         size_t h_head_addr = c_head_addr + h_idx * w;
1638         for (size_t w_idx = 0; w_idx < w; w_idx++) {
1639           size_t dst_idx = h_head_addr + w_idx;
1640           size_t c1_idx = c_idx / c0;
1641           size_t c0_idx = c_idx % c0;
1642           size_t src_idx = n_idx * c1hwc0 + c1_idx * hwc0 + h_idx * wc0 + w_idx * c0 + c0_idx;
1643           SetData(size, false, src_idx, dst_idx, args, result);
1644         }
1645       }
1646     }
1647   }
1648   return true;
1649 }
1650 
1651 bool NchwToC1hwncoc0(const FormatArgs &args, void *result) {
1652   // trans nchw to c1hwncoc0
1653   MS_LOG(DEBUG) << "Trans format from nchw to c1hwncoc0.";
1654   MS_EXCEPTION_IF_NULL(result);
1655   size_t size = 0;
1656   size_t total_size = 0;
1657   if (!CheckArgs(args, &size, &total_size)) {
1658     MS_LOG(ERROR) << "Check args failed.";
1659     return false;
1660   }
1661   auto n = args.host_shape[kN];
1662   auto c = args.host_shape[kC];
1663   auto h = args.host_shape[kH];
1664   auto w = args.host_shape[kW];
1665   const int co_idx = 4;
1666   const int c0_idx = 5;
1667   auto c1 = args.device_shape[0];
1668   auto co = args.device_shape[co_idx];
1669   auto c0 = args.device_shape[c0_idx];
1670 
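  // C1HWNCoC0 layout is (C1, H, W, N, Co, C0).  Only the diagonal co_i == c0_i
  // (and only while c1_i * c0 + c0_i < c) carries real data; every other
  // position is zero padded.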
1671   for (size_t c1_i = 0; c1_i < c1; c1_i++) {
1672     for (size_t h_i = 0; h_i < h; h_i++) {
1673       for (size_t w_i = 0; w_i < w; w_i++) {
1674         for (size_t n_i = 0; n_i < n; n_i++) {
1675           for (size_t co_i = 0; co_i < co; co_i++) {
1676             for (size_t c0_i = 0; c0_i < c0; c0_i++) {
1677               size_t dst_idx = c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 +
1678                                co_i * c0 + c0_i;
1679               size_t c_i = c0_i + c1_i * c0;
1680               size_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
1681               auto pad_zero = !(c_i < c && c0_i == co_i);
1682               SetData(size, pad_zero, src_idx, dst_idx, args, result);
1683             }
1684           }
1685         }
1686       }
1687     }
1688   }
1689   return true;
1690 }
1691 
1692 bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) {
1693   // trans c1hwncoc0 to nchw
1694   MS_LOG(DEBUG) << "Trans format from c1hwncoc0 to nchw";
1695   MS_EXCEPTION_IF_NULL(result);
1696   size_t size = 0;
1697   size_t total_size = 0;
1698   if (!CheckArgs(args, &size, &total_size)) {
1699     MS_LOG(ERROR) << "Check args failed.";
1700     return false;
1701   }
1702   auto n = args.host_shape[kN];
1703   auto c = args.host_shape[kC];
1704   auto h = args.host_shape[kH];
1705   auto w = args.host_shape[kW];
1706   const int co_idx = 4;
1707   const int c0_idx = 5;
1708   auto co = args.device_shape[co_idx];
1709   auto c0 = args.device_shape[c0_idx];
1710   for (size_t n_i = 0; n_i < n; n_i++) {
1711     for (size_t c_i = 0; c_i < c; c_i++) {
1712       for (size_t h_i = 0; h_i < h; h_i++) {
1713         for (size_t w_i = 0; w_i < w; w_i++) {
1714           size_t dst_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
1715           size_t c1_i = c_i / kCubeSize;
1716           size_t c0_i = c_i % kCubeSize;
1717           size_t co_i = c0_i;
1718           size_t src_idx =
1719             c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 + co_i * c0 + c0_i;
1720           SetData(size, false, src_idx, dst_idx, args, result);
1721         }
1722       }
1723     }
1724   }
1725   return true;
1726 }
1727 
1728 bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result) {
1729   MS_LOG(DEBUG) << "Trans from ndc1hwc0 to ncdhw";
1730   MS_EXCEPTION_IF_NULL(result);
1731 
1732   if (args.host_shape.size() != kNcdhw) {
1733     MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
1734     return false;
1735   }
1736   auto size = abstract::TypeIdSize(args.src_data_type);
1737   if (size < 1) {
1738     MS_LOG(ERROR) << "Illegal dtype.";
1739     return false;
1740   }
1741   auto total_size = abstract::ShapeSize(args.device_shape) * size;
1742   if (total_size != args.device_size) {
1743     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1744     return false;
1745   }
1746   auto n = args.host_shape[N_ncdhw];
1747   auto c = args.host_shape[C_ncdhw];
1748   auto d = args.host_shape[D_ncdhw];
1749   auto h = args.host_shape[H_ncdhw];
1750   auto w = args.host_shape[W_ncdhw];
1751   auto c1 = args.device_shape[C1_ndc1hwc0];
1752   auto c0 = args.device_shape[C0_ndc1hwc0];
1753   const size_t cdhw = c * d * h * w;
1754   const size_t dhw = d * h * w;
1755   const size_t hw = h * w;
1756   const size_t dc1hwc0 = d * c1 * h * w * c0;
1757   const size_t c1hwc0 = c1 * h * w * c0;
1758   const size_t hwc0 = h * w * c0;
1759   const size_t wc0 = w * c0;
1760 
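  // NDC1HWC0 is (N, D, C1, H, W, C0): split c into (c1, c0) and gather each
  // NCDHW element from
  //   src = n*D*C1*H*W*C0 + d*C1*H*W*C0 + c1*H*W*C0 + h*W*C0 + w*C0 + c0.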
1761   for (size_t n_i = 0; n_i < n; n_i++) {
1762     size_t n_head = n_i * cdhw;
1763     for (size_t c_i = 0; c_i < c; c_i++) {
1764       size_t c_head = n_head + c_i * dhw;
1765       for (size_t d_i = 0; d_i < d; d_i++) {
1766         size_t d_head = c_head + d_i * hw;
1767         for (size_t h_i = 0; h_i < h; h_i++) {
1768           size_t h_head = d_head + h_i * w;
1769           for (size_t w_i = 0; w_i < w; w_i++) {
1770             size_t dst_i = h_head + w_i;
1771             size_t c1_i = c_i / c0;
1772             size_t c0_i = c_i % c0;
1773             auto src_idx = n_i * dc1hwc0 + d_i * c1hwc0 + c1_i * hwc0 + h_i * wc0 + w_i * c0 + c0_i;
1774             SetData(size, false, src_idx, dst_i, args, result);
1775           }
1776         }
1777       }
1778     }
1779   }
1780   return true;
1781 }
1782 
1783 bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) {
1784   MS_LOG(DEBUG) << "Trans from ncdhw to ndc1hwc0";
1785   MS_EXCEPTION_IF_NULL(result);
1786 
1787   if (args.host_shape.size() != kNcdhw) {
1788     MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
1789     return false;
1790   }
1791   auto size = abstract::TypeIdSize(args.src_data_type);
1792   if (size < 1) {
1793     MS_LOG(ERROR) << "Illegal dtype.";
1794     return false;
1795   }
1796   auto total_size = abstract::ShapeSize(args.device_shape) * size;
1797   if (total_size != args.device_size) {
1798     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1799     return false;
1800   }
1801 
1802   auto n = args.host_shape[N_ncdhw];
1803   auto c = args.host_shape[C_ncdhw];
1804   auto d = args.host_shape[D_ncdhw];
1805   auto h = args.host_shape[H_ncdhw];
1806   auto w = args.host_shape[W_ncdhw];
1807   auto c0 = kCubeSize;
1808   auto c1 = DivCeil(c, c0);
1809   const size_t cdhw = c * d * h * w;
1810   const size_t dhw = d * h * w;
1811   const size_t hw = h * w;
1812   const size_t dc1hwc0 = d * c1 * h * w * c0;
1813   const size_t c1hwc0 = c1 * h * w * c0;
1814   const size_t hwc0 = h * w * c0;
1815   const size_t wc0 = w * c0;
1816 
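  // Forward transform to NDC1HWC0: iterate over the device layout
  // (N, D, C1, H, W, C0) and scatter from NCDHW; channels beyond the real c
  // are zero padded.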
1817   for (size_t n_i = 0; n_i < n; n_i++) {
1818     size_t n_head = n_i * dc1hwc0;
1819     for (size_t d_i = 0; d_i < d; d_i++) {
1820       size_t d_head = n_head + d_i * c1hwc0;
1821       for (size_t c1_i = 0; c1_i < c1; c1_i++) {
1822         size_t c1_head = d_head + c1_i * hwc0;
1823         for (size_t h_i = 0; h_i < h; h_i++) {
1824           size_t h_head = c1_head + h_i * wc0;
1825           for (size_t w_i = 0; w_i < w; w_i++) {
1826             size_t w_head = h_head + w_i * c0;
1827             for (size_t c0_i = 0; c0_i < c0; c0_i++) {
1828               size_t dst_i = c0_i + w_head;
1829               size_t c_i = c0_i + c1_i * c0;
1830               size_t src_i = n_i * cdhw + c_i * dhw + d_i * hw + h_i * w + w_i;
1831               auto pad_zero = c_i >= c;
1832               SetData(size, pad_zero, src_i, dst_i, args, result);
1833             }
1834           }
1835         }
1836       }
1837     }
1838   }
1839   return true;
1840 }
1841 
1842 bool NcdhwToFracZ3D(const FormatArgs &args, void *result) {
1843   MS_LOG(DEBUG) << "Trans from ncdhw to frac_z_3d";
1844   MS_EXCEPTION_IF_NULL(result);
1845 
1846   if (args.host_shape.size() != kNcdhw) {
1847     MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
1848     return false;
1849   }
1850   auto size = abstract::TypeIdSize(args.src_data_type);
1851   if (size < 1) {
1852     MS_LOG(ERROR) << "Illegal dtype.";
1853     return false;
1854   }
1855   auto total_size = abstract::ShapeSize(args.device_shape) * size;
1856   if (total_size != args.device_size) {
1857     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1858     return false;
1859   }
1860 
1861   auto n = args.host_shape[N_ncdhw];
1862   auto c = args.host_shape[C_ncdhw];
1863   auto d = args.host_shape[D_ncdhw];
1864   auto h = args.host_shape[H_ncdhw];
1865   auto w = args.host_shape[W_ncdhw];
1866 
1867   auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize;
1868   const size_t c0 = 16;
1869   auto c1 = DivCeil(c, c0);
1870   auto hw = h * w;
1871   auto dhw = d * hw;
1872   auto cdhw = c * dhw;
1873   auto n1n0c0 = n1n0 * c0;
1874   auto wn1n0c0 = w * n1n0c0;
1875   auto hwn1n0c0 = h * wn1n0c0;
1876   auto c1hwn1n0c0 = c1 * hwn1n0c0;
1877 
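  // FRACTAL_Z_3D is (D, C1, H, W, N1*N0, C0) with n rounded up to a multiple
  // of kCubeSize and c split into c1 groups of c0 = 16; out-of-range n or c
  // positions are zero padded.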
1878   for (size_t d_i = 0; d_i < d; d_i++) {
1879     for (size_t c1_i = 0; c1_i < c1; c1_i++) {
1880       for (size_t h_i = 0; h_i < h; h_i++) {
1881         for (size_t w_i = 0; w_i < w; w_i++) {
1882           for (size_t n1n0_i = 0; n1n0_i < n1n0; n1n0_i++) {
1883             for (size_t c0_i = 0; c0_i < c0; c0_i++) {
1884               auto dst_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i;
1885               // ncdhw
1886               size_t src_i = n1n0_i * cdhw + (c1_i * c0 + c0_i) * dhw + d_i * hw + h_i * w + w_i;
1887               auto pad_zero = ((c1_i * c0 + c0_i) >= c) || (n1n0_i >= n);
1888               SetData(size, pad_zero, src_i, dst_i, args, result);
1889             }
1890           }
1891         }
1892       }
1893     }
1894   }
1895   return true;
1896 }
1897 
1898 bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
1899   MS_LOG(DEBUG) << "Trans from frac_z_3d to ncdhw";
1900   MS_EXCEPTION_IF_NULL(result);
1901 
1902   if (args.host_shape.size() != kNcdhw) {
1903     MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
1904     return false;
1905   }
1906   auto size = abstract::TypeIdSize(args.src_data_type);
1907   if (size < 1) {
1908     MS_LOG(ERROR) << "Illegal dtype.";
1909     return false;
1910   }
1911   auto total_size = abstract::ShapeSize(args.device_shape) * size;
1912   if (total_size != args.device_size) {
1913     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1914     return false;
1915   }
1916   auto n = args.host_shape[N_ncdhw];
1917   auto c = args.host_shape[C_ncdhw];
1918   auto d = args.host_shape[D_ncdhw];
1919   auto h = args.host_shape[H_ncdhw];
1920   auto w = args.host_shape[W_ncdhw];
1921   const int kFZ3D_C0 = 3;
1922   auto c0 = args.device_shape[kFZ3D_C0];
1923   auto c1 = DivCeil(c, kCubeSize);
1924   auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize;
1925   auto n1n0c0 = n1n0 * c0;
1926   auto wn1n0c0 = w * n1n0c0;
1927   auto hwn1n0c0 = h * wn1n0c0;
1928   auto c1hwn1n0c0 = c1 * hwn1n0c0;
1929   auto hw = h * w;
1930   auto dhw = d * hw;
1931   auto cdhw = c * dhw;
1932 
1933   for (size_t n_i = 0; n_i < n; n_i++) {
1934     size_t n_head = n_i * cdhw;
1935     for (size_t c_i = 0; c_i < c; c_i++) {
1936       size_t c_head = n_head + c_i * dhw;
1937       for (size_t d_i = 0; d_i < d; d_i++) {
1938         size_t d_head = c_head + d_i * hw;
1939         for (size_t h_i = 0; h_i < h; h_i++) {
1940           size_t h_head = d_head + h_i * w;
1941           for (size_t w_i = 0; w_i < w; w_i++) {
1942             size_t dst_i = h_head + w_i;
1943             size_t c1_i = c_i / c0;
1944             size_t c0_i = c_i % c0;
1945             size_t nc_i = n_i;
1946             size_t src_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + nc_i * c0 + c0_i;
1947             SetData(size, false, src_i, dst_i, args, result);
1948           }
1949         }
1950       }
1951     }
1952   }
1953   return true;
1954 }
1955 
1956 bool CheckDimsAndDtypes(const FormatArgs &args, size_t *size) {
1957   if (args.host_shape.size() != kNchwDims) {
1958     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
1959     return false;
1960   }
1961   *size = abstract::TypeIdSize(args.src_data_type);
1962   if (*size < 1) {
1963     MS_LOG(ERROR) << "Illegal dtype";
1964     return false;
1965   }
1966   return true;
1967 }
1968 
1969 bool NchwFracZTransWithGroups(const FormatArgs &args, void *result, bool to_device, int64_t groups) {
1970   MS_EXCEPTION_IF_NULL(result);
1971   size_t size = 0;
1972   if (!(CheckDimsAndDtypes(args, &size))) {
1973     MS_LOG(ERROR) << "Illegal input args";
1974     return false;
1975   }
1976   if (groups <= 0) {
1977     MS_LOG(EXCEPTION) << "The value of groups should be greater than 0, but got " << groups;
1978   }
1979   auto n_dim = args.host_shape[kN];
1980   auto c_dim = args.host_shape[kC];
1981   auto h_dim = args.host_shape[kH];
1982   auto w_dim = args.host_shape[kW];
1983   const size_t d_dim = 1;
1984   size_t group_size = LongToSize(groups);
1985   auto cin_ori = c_dim;
1986   auto cout_ori = n_dim / group_size;
1987   if (cin_ori == 0 || cout_ori == 0) {
1988     MS_LOG(ERROR) << "cin_ori and cout_ori must not be 0";
1989     return false;
1990   }
1991   size_t e_mult = std::min(Lcm(Lcm(cin_ori, kCubeSize) / cin_ori, Lcm(cout_ori, kCubeSize) / cout_ori), group_size);
1992   if (e_mult == 0) {
1993     MS_LOG(EXCEPTION) << "The value of e_mult should be greater than 0, but got " << e_mult;
1994   }
1995   size_t cin_opt = DivCeil(e_mult * cin_ori, kCubeSize) * kCubeSize;
1996   size_t cout_opt = DivCeil(e_mult * cout_ori, kCubeSize) * kCubeSize;
1997   size_t c1_dim = cin_opt / kCubeSize;
1998   size_t dst_size = to_device ? GetShapeSize(args.device_shape) * size : GetShapeSize(args.host_shape) * size;
1999   if (dst_size == 0) {
2000     return true;
2001   }
2002   auto ret = memset_s(result, dst_size, 0, dst_size);
2003   if (ret != EOK) {
2004     MS_LOG(ERROR) << "memset failed";
2005     return false;
2006   }
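  // Grouped FRACTAL_Z: e_mult groups are packed together so that both the
  // packed input channels (e_mult * cin_ori) and output channels
  // (e_mult * cout_ori) round up to multiples of kCubeSize.  The buffer is
  // zero-filled first, then every real (g, c, h, w, n) element is copied to
  // (to_device) or from (otherwise) its fractal offset.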
2007   for (size_t g = 0; g < group_size; ++g) {
2008     for (size_t d = 0; d < d_dim; ++d) {
2009       for (size_t c = 0; c < c_dim; ++c) {
2010         for (size_t h = 0; h < h_dim; ++h) {
2011           for (size_t w = 0; w < w_dim; ++w) {
2012             for (size_t n = 0; n < cout_ori; ++n) {
2013               size_t e_val = g % e_mult;
2014               size_t dst_ci = e_val * cin_ori + c;
2015               size_t dst_co = e_val * cout_ori + n;
2016               size_t src_co = g * cout_ori + n;
2017               size_t temporary = dst_ci % kCubeSize;
2018               size_t dev_idx = (g / e_mult) * d_dim * c1_dim * h_dim * w_dim * cout_opt * kCubeSize +
2019                                d * c1_dim * h_dim * w_dim * cout_opt * kCubeSize +
2020                                (dst_ci / kCubeSize) * h_dim * w_dim * cout_opt * kCubeSize +
2021                                h * w_dim * cout_opt * kCubeSize + w * cout_opt * kCubeSize + dst_co * kCubeSize +
2022                                temporary;
2023               size_t hst_idx =
2024                 src_co * c_dim * d_dim * h_dim * w_dim + c * d_dim * h_dim * w_dim + d * h_dim * w_dim + h * w_dim + w;
2025               if (to_device) {
2026                 SetData(size, false, hst_idx, dev_idx, args, result);
2027               } else {
2028                 SetData(size, false, dev_idx, hst_idx, args, result);
2029               }
2030             }
2031           }
2032         }
2033       }
2034     }
2035   }
2036   return true;
2037 }
2038 
2039 bool NchwToFracZWithGroups(const FormatArgs &args, void *result, int64_t groups) {
2040   MS_LOG(DEBUG) << "Trans format from nchw to frac_z with groups=" << groups;
2041   return NchwFracZTransWithGroups(args, result, true, groups);
2042 }
2043 
2044 bool FracZToNchwWithGroups(const FormatArgs &args, void *result, int64_t groups) {
2045   MS_LOG(DEBUG) << "Trans format from frac_z to nchw with groups=" << groups;
2046   return NchwFracZTransWithGroups(args, result, false, groups);
2047 }
2048 }  // namespace trans
2049 }  // namespace mindspore
2050