1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "common/trans.h"
17 #include <functional>
18 #include <numeric>
19 #include <utility>
20 #include <algorithm>
21 #include "utils/ms_utils.h"
22 #include "abstract/utils.h"
23 #include "backend/kernel_compiler/kernel.h"
24 #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
25 #include "runtime/device/convert_tensor_utils.h"
26 #include "utils/convert_utils.h"
27 #include "utils/log_adapter.h"
28 #include "utils/utils.h"
29
30 using mindspore::abstract::Shape;
31 namespace mindspore {
32 namespace trans {
// Element byte widths accepted by SetData's switch below.
const int b1 = 1;  // 1-byte elements (uint8_t)
const int b2 = 2;  // 2-byte elements (uint16_t / float16 bit pattern)
const int b4 = 4;  // 4-byte elements (uint32_t / float bit pattern)
const int b8 = 8;  // 8-byte elements (uint64_t / double bit pattern)
SetData(size_t size,bool pad_zero,size_t src_idx,size_t dst_idx,const FormatArgs & args,void * result)37 inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) {
38 switch (size) {
39 case b1:
40 static_cast<uint8_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint8_t *>(args.data)[src_idx];
41 break;
42 case b2:
43 static_cast<uint16_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint16_t *>(args.data)[src_idx];
44 break;
45 case b4:
46 static_cast<uint32_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint32_t *>(args.data)[src_idx];
47 break;
48 case b8:
49 static_cast<uint64_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint64_t *>(args.data)[src_idx];
50 break;
51 default:
52 MS_LOG(EXCEPTION) << "Trans data not support size " << size;
53 }
54 }
55
56 // greatest common divsor
// Greatest common divisor via the Euclidean algorithm.
// NOTE: by convention in this file, Gcd(a, 0) returns 0 (callers guard b).
size_t Gcd(size_t a, size_t b) {
  if (b == 0) {
    return 0;
  }
  while (true) {
    const size_t remainder = a % b;
    if (remainder == 0) {
      return b;
    }
    a = b;
    b = remainder;
  }
}
69
70 // least common multiple
Lcm(size_t a,size_t b)71 size_t Lcm(size_t a, size_t b) {
72 if (b == 0) {
73 return 0;
74 }
75 size_t ret = (a * b) / (Gcd(a, b));
76 return ret;
77 }
78
// Ceiling division: smallest integer >= n1 / n2.
// Returns 0 when the divisor is 0 (callers rely on this, no exception).
template <typename T>
T DivCeil(T n1, T n2) {
  if (n2 == 0) {
    return 0;
  }
  return (n1 + n2 - 1) / n2;
}
86
// Number of elements in a tensor of the given shape.
// An empty (rank-0) shape yields 1, matching the multiplicative identity.
size_t GetShapeSize(const std::vector<size_t> &shape) {
  size_t element_count = 1;
  for (const size_t dim : shape) {
    element_count *= dim;
  }
  return element_count;
}
90
// Every (host type -> device type) cast CastKernel can perform.
// A pair absent from this list (see mode_map) is an unsupported conversion.
enum class DataTypeTransMode {
  FROM_FLOAT_TO_FLOAT16,
  FROM_FLOAT_TO_INT32,
  FROM_FLOAT16_TO_FLOAT,
  FROM_FLOAT16_TO_INT32,
  FROM_FLOAT16_TO_UINT8,
  FROM_INT32_TO_FLOAT,
  FROM_INT32_TO_FLOAT16,
  FROM_INT32_TO_UINT8,
  FROM_INT32_TO_INT8,
  FROM_INT32_TO_INT64,
  FROM_INT32_TO_BOOL,
  FROM_UINT8_TO_FLOAT,
  FROM_UINT8_TO_INT32,
  FROM_UINT8_TO_FLOAT16,
  FROM_INT8_TO_FLOAT,
  FROM_INT8_TO_FLOAT16,
  FROM_INT8_TO_INT32,
  FROM_INT64_TO_INT32,
  FROM_UINT16_TO_INT32,
  FROM_BOOL_TO_FLOAT,
  FROM_BOOL_TO_INT32,
  FROM_BOOL_TO_UINT8,
  FROM_BOOL_TO_FLOAT16,
  FROM_FLOAT64_TO_FLOAT32,
  FROM_FLOAT32_TO_FLOAT64
};
118
// Lookup table: (host TypeId, device TypeId) -> conversion mode for CastKernel.
// A pair missing here means the dtype transformation is not supported.
const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{
  {std::pair<TypeId, TypeId>(kNumberTypeFloat64, kNumberTypeFloat32), DataTypeTransMode::FROM_FLOAT64_TO_FLOAT32},
  {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat64), DataTypeTransMode::FROM_FLOAT32_TO_FLOAT64},
  {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat16), DataTypeTransMode::FROM_FLOAT_TO_FLOAT16},
  {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeInt32), DataTypeTransMode::FROM_FLOAT_TO_INT32},
  {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeFloat32), DataTypeTransMode::FROM_FLOAT16_TO_FLOAT},
  {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeInt32), DataTypeTransMode::FROM_FLOAT16_TO_INT32},
  {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeUInt8), DataTypeTransMode::FROM_FLOAT16_TO_UINT8},
  {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat32), DataTypeTransMode::FROM_INT32_TO_FLOAT},
  {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat16), DataTypeTransMode::FROM_INT32_TO_FLOAT16},
  {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeUInt8), DataTypeTransMode::FROM_INT32_TO_UINT8},
  {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt8), DataTypeTransMode::FROM_INT32_TO_INT8},
  {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt64), DataTypeTransMode::FROM_INT32_TO_INT64},
  {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeBool), DataTypeTransMode::FROM_INT32_TO_BOOL},
  {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat32), DataTypeTransMode::FROM_UINT8_TO_FLOAT},
  {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeInt32), DataTypeTransMode::FROM_UINT8_TO_INT32},
  {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat16), DataTypeTransMode::FROM_UINT8_TO_FLOAT16},
  {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat32), DataTypeTransMode::FROM_INT8_TO_FLOAT},
  {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat16), DataTypeTransMode::FROM_INT8_TO_FLOAT16},
  {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeInt32), DataTypeTransMode::FROM_INT8_TO_INT32},
  {std::pair<TypeId, TypeId>(kNumberTypeInt64, kNumberTypeInt32), DataTypeTransMode::FROM_INT64_TO_INT32},
  {std::pair<TypeId, TypeId>(kNumberTypeUInt16, kNumberTypeInt32), DataTypeTransMode::FROM_UINT16_TO_INT32},
  {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeInt32), DataTypeTransMode::FROM_BOOL_TO_INT32},
  {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat), DataTypeTransMode::FROM_BOOL_TO_FLOAT},
  {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeUInt8), DataTypeTransMode::FROM_BOOL_TO_UINT8},
  {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat16), DataTypeTransMode::FROM_BOOL_TO_FLOAT16}};
145
// Validates the buffer described by `args` before a dtype cast:
// both type sizes must be known (> 0) and the total byte size must divide
// into exactly host_shape_size elements of the host type. Throws on failure.
// NOTE(review): only the source (host) type size is checked against the
// element count; the destination buffer size is assumed adequate by callers.
void CheckMemSize(const TypeIdArgs &args) {
  auto src_type_size = abstract::TypeIdSize(args.host_data_type);
  auto dst_type_size = abstract::TypeIdSize(args.device_data_type);
  if (src_type_size < 1 || dst_type_size < 1) {
    MS_LOG(EXCEPTION) << "Invalid src or dst data type.";
  }
  if (args.data_size / src_type_size != args.host_shape_size) {
    MS_LOG(EXCEPTION) << "Invalid src or dst data size.";
  }
}
156
157 template <typename SrcT, typename DstT>
TransDataSrc2Dst(const TypeIdArgs & args,void * dst,const size_t data_size)158 void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) {
159 CheckMemSize(args);
160 for (size_t idx = 0; idx != data_size; idx++) {
161 SrcT src_data = static_cast<const SrcT *>(args.data)[idx];
162 static_cast<DstT *>(dst)[idx] = static_cast<DstT>(src_data);
163 }
164 }
165
166 template <typename SrcT>
TransDataSrc2Fp16(const TypeIdArgs & args,void * dst,const size_t data_size)167 void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const size_t data_size) {
168 CheckMemSize(args);
169 auto src_data = static_cast<const SrcT *>(args.data);
170 auto half_data = static_cast<float16 *>(dst);
171 for (size_t i = 0; i < data_size; i++) {
172 half_data[i] = float16(src_data[i]);
173 }
174 }
175
CastKernel(const TypeIdArgs & args,void * dst,const size_t data_size,const DataTypeTransMode mode)176 bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const DataTypeTransMode mode) {
177 using DtypeKernel = std::function<void(const TypeIdArgs &, void *, const size_t)>;
178 const std::map<DataTypeTransMode, DtypeKernel> cast_kernel_map{
179 {DataTypeTransMode::FROM_FLOAT_TO_INT32, TransDataSrc2Dst<float, int32_t>},
180 {DataTypeTransMode::FROM_FLOAT64_TO_FLOAT32, TransDataSrc2Dst<double, float>},
181 {DataTypeTransMode::FROM_FLOAT32_TO_FLOAT64, TransDataSrc2Dst<float, double>},
182 {DataTypeTransMode::FROM_FLOAT16_TO_INT32, TransDataSrc2Dst<float16, int32_t>},
183 {DataTypeTransMode::FROM_FLOAT16_TO_UINT8, TransDataSrc2Dst<float16, uint8_t>},
184 {DataTypeTransMode::FROM_INT32_TO_FLOAT, TransDataSrc2Dst<int32_t, float>},
185 {DataTypeTransMode::FROM_INT32_TO_INT8, TransDataSrc2Dst<int32_t, int8_t>},
186 {DataTypeTransMode::FROM_INT32_TO_INT64, TransDataSrc2Dst<int32_t, int64_t>},
187 {DataTypeTransMode::FROM_INT32_TO_UINT8, TransDataSrc2Dst<int32_t, uint8_t>},
188 {DataTypeTransMode::FROM_INT32_TO_BOOL, TransDataSrc2Dst<int32_t, int8_t>},
189 {DataTypeTransMode::FROM_INT32_TO_FLOAT16, TransDataSrc2Fp16<int32_t>},
190 {DataTypeTransMode::FROM_UINT8_TO_FLOAT, TransDataSrc2Dst<uint8_t, float>},
191 {DataTypeTransMode::FROM_UINT8_TO_INT32, TransDataSrc2Dst<uint8_t, int32_t>},
192 {DataTypeTransMode::FROM_UINT8_TO_FLOAT16, TransDataSrc2Fp16<uint8_t>},
193 {DataTypeTransMode::FROM_INT8_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
194 {DataTypeTransMode::FROM_INT8_TO_FLOAT16, TransDataSrc2Fp16<int8_t>},
195 {DataTypeTransMode::FROM_INT8_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
196 {DataTypeTransMode::FROM_INT64_TO_INT32, TransDataSrc2Dst<int64_t, int32_t>},
197 {DataTypeTransMode::FROM_UINT16_TO_INT32, TransDataSrc2Dst<uint16_t, int32_t>},
198 {DataTypeTransMode::FROM_BOOL_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
199 {DataTypeTransMode::FROM_BOOL_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
200 {DataTypeTransMode::FROM_BOOL_TO_UINT8, TransDataSrc2Dst<int8_t, uint8_t>},
201 {DataTypeTransMode::FROM_BOOL_TO_FLOAT16, TransDataSrc2Fp16<int8_t>}};
202
203 if (mode == DataTypeTransMode::FROM_FLOAT_TO_FLOAT16) {
204 device::FloatToHalf(dst, args.data, data_size);
205 return true;
206 } else if (mode == DataTypeTransMode::FROM_FLOAT16_TO_FLOAT) {
207 device::HalfToFloat(dst, args.data, data_size);
208 return true;
209 }
210 auto iter = cast_kernel_map.find(mode);
211 if (iter != cast_kernel_map.end()) {
212 iter->second(args, dst, data_size);
213 return true;
214 } else {
215 MS_LOG(ERROR) << "Unsupported datatype trans";
216 return false;
217 }
218 }
219
220 namespace {
HasShapeDynamic(const std::vector<int64_t> & shape_list)221 bool HasShapeDynamic(const std::vector<int64_t> &shape_list) {
222 return std::any_of(shape_list.begin(), shape_list.end(), [](int64_t shape) { return shape == Shape::SHP_ANY; });
223 }
224
225 template <typename T>
CheckDims(const std::vector<T> & shape)226 bool CheckDims(const std::vector<T> &shape) {
227 if (shape.size() != kNchwDims) {
228 MS_LOG(ERROR) << "Host shape dims should be 4";
229 return false;
230 }
231 return true;
232 }
233
// NCHW host shape maps to the device unchanged; only validates the rank.
std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape) {
  if (!CheckDims(shape)) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
  return shape;
}
240
// Dynamic-shape variant of NchwDeviceShape: identity mapping after a rank check.
std::vector<int64_t> NchwDeviceDynamicShape(const std::vector<int64_t> &shape) {
  if (!CheckDims(shape)) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
  return shape;
}
247
NhwcDeviceShape(const std::vector<size_t> & shape)248 std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
249 if (!CheckDims(shape)) {
250 MS_LOG(EXCEPTION) << "Ccheck dims failed.";
251 }
252 std::vector<size_t> device_shape;
253 device_shape.push_back(shape[kN]);
254 device_shape.push_back(shape[kH]);
255 device_shape.push_back(shape[kW]);
256 device_shape.push_back(shape[kC]);
257 return device_shape;
258 }
259
NhwcDeviceDynamicShape(const std::vector<int64_t> & shape)260 std::vector<int64_t> NhwcDeviceDynamicShape(const std::vector<int64_t> &shape) {
261 if (!CheckDims(shape)) {
262 MS_LOG(EXCEPTION) << "Ccheck dims failed.";
263 }
264 std::vector<int64_t> device_shape;
265 device_shape.push_back(shape[kN]);
266 device_shape.push_back(shape[kH]);
267 device_shape.push_back(shape[kW]);
268 device_shape.push_back(shape[kC]);
269 return device_shape;
270 }
271
HwchDeviceShape(const std::vector<size_t> & shape)272 std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
273 if (!CheckDims(shape)) {
274 MS_LOG(EXCEPTION) << "Check dims failed.";
275 }
276 std::vector<size_t> device_shape;
277 device_shape.push_back(shape[kH]);
278 device_shape.push_back(shape[kW]);
279 device_shape.push_back(shape[kC]);
280 device_shape.push_back(shape[kN]);
281 return device_shape;
282 }
283
HwchDeviceDynamicShape(const std::vector<int64_t> & shape)284 std::vector<int64_t> HwchDeviceDynamicShape(const std::vector<int64_t> &shape) {
285 if (!CheckDims(shape)) {
286 MS_LOG(EXCEPTION) << "Check dims failed.";
287 }
288 std::vector<int64_t> device_shape;
289 device_shape.push_back(shape[kH]);
290 device_shape.push_back(shape[kW]);
291 device_shape.push_back(shape[kC]);
292 device_shape.push_back(shape[kN]);
293 return device_shape;
294 }
295
FracZDeviceShape(const std::vector<size_t> & shape)296 std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) {
297 if (!CheckDims(shape)) {
298 MS_LOG(EXCEPTION) << "Check dims failed.";
299 }
300 std::vector<size_t> device_shape;
301 const size_t cout16 = ((shape[kN] + kCubeSize - 1) / kCubeSize) * kCubeSize;
302 const size_t cin16 = ((shape[kC] + kCubeSize - 1) / kCubeSize) * kCubeSize;
303 device_shape.push_back(shape[kH] * shape[kW] * cin16 / kCubeSize);
304 device_shape.push_back(cout16 / kCubeSize);
305 device_shape.push_back(kCubeSize);
306 device_shape.push_back(kCubeSize);
307 return device_shape;
308 }
309
// NCHW -> FRACTAL_Z for dynamic shapes: any device dim derived from an
// unknown (SHP_ANY) host dim is itself emitted as SHP_ANY.
std::vector<int64_t> FracZDeviceDynamicShape(const std::vector<int64_t> &shape) {
  if (!CheckDims(shape)) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
  auto tmp = SizeToLong(kCubeSize);
  std::vector<int64_t> device_shape;
  // First device dim is H * W * C1; it needs C, H and W to all be known.
  if (HasShapeDynamic({shape[kC], shape[kH], shape[kW]})) {
    device_shape.push_back(Shape::SHP_ANY);
  } else {
    const int64_t cin16 = ((shape[kC] + tmp - 1) / tmp) * tmp;  // C rounded up to cube size
    device_shape.push_back(shape[kH] * shape[kW] * cin16 / tmp);
  }
  // Second device dim is N1 = ceil(N / cube).
  if (shape[kN] == Shape::SHP_ANY) {
    device_shape.push_back(Shape::SHP_ANY);
  } else {
    const int64_t cout16 = ((shape[kN] + tmp - 1) / tmp) * tmp;  // N rounded up to cube size
    device_shape.push_back(cout16 / tmp);
  }
  device_shape.push_back(tmp);
  device_shape.push_back(tmp);
  return device_shape;
}
332
Nc1hwc0DeviceShape(const std::vector<size_t> & shape)333 std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
334 if (!CheckDims(shape)) {
335 MS_LOG(EXCEPTION) << "Check dims failed.";
336 }
337 std::vector<size_t> device_shape;
338 const size_t C1 = (shape[kC] + kCubeSize - 1) / kCubeSize;
339 const size_t C0 = kCubeSize;
340 device_shape.push_back(shape[kN]);
341 device_shape.push_back(C1);
342 device_shape.push_back(shape[kH]);
343 device_shape.push_back(shape[kW]);
344 device_shape.push_back(C0);
345 return device_shape;
346 }
347
Nc1hwc0DeviceDynamicShape(const std::vector<int64_t> & shape)348 std::vector<int64_t> Nc1hwc0DeviceDynamicShape(const std::vector<int64_t> &shape) {
349 if (!CheckDims(shape)) {
350 MS_LOG(EXCEPTION) << "Check dims failed.";
351 }
352 std::vector<int64_t> device_shape;
353 auto tmp = SizeToLong(kCubeSize);
354 const int64_t C1 = (shape[kC] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[kC] + tmp - 1) / tmp;
355 const int64_t C0 = tmp;
356 device_shape.push_back(shape[kN]);
357 device_shape.push_back(C1);
358 device_shape.push_back(shape[kH]);
359 device_shape.push_back(shape[kW]);
360 device_shape.push_back(C0);
361 return device_shape;
362 }
363
Ndc1hwc0DeviceShape(const std::vector<size_t> & shape)364 std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) {
365 // NCDHW
366 if (shape.size() != kNcdhw) {
367 MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
368 }
369 std::vector<size_t> device_shape;
370 const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
371 const size_t C0 = kCubeSize;
372 device_shape.push_back(shape[N_ncdhw]);
373 device_shape.push_back(shape[D_ncdhw]);
374 device_shape.push_back(C1);
375 device_shape.push_back(shape[H_ncdhw]);
376 device_shape.push_back(shape[W_ncdhw]);
377 device_shape.push_back(C0);
378 return device_shape;
379 }
380
// NCDHW -> NDC1HWC0 for dynamic shapes: an unknown (SHP_ANY) channel dim
// yields an unknown C1; C0 is always the cube size.
std::vector<int64_t> Ndc1hwc0DeviceDynamicShape(const std::vector<int64_t> &shape) {
  // NCDHW
  if (shape.size() != kNcdhw) {
    MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
  }
  auto tmp = SizeToLong(kCubeSize);
  std::vector<int64_t> device_shape;
  // shape[1] is the channel dim of NCDHW.
  const int64_t C1 = (shape[1] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[1] + tmp - 1) / tmp;
  const int64_t C0 = tmp;
  device_shape.push_back(shape[N_ncdhw]);
  device_shape.push_back(shape[D_ncdhw]);
  device_shape.push_back(C1);
  device_shape.push_back(shape[H_ncdhw]);
  device_shape.push_back(shape[W_ncdhw]);
  device_shape.push_back(C0);
  return device_shape;
}
398
Fracz3DDeviceShape(const std::vector<size_t> & shape)399 std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) {
400 // NCDHW -> Frac_Z_3D
401 if (shape.size() != kNcdhw) {
402 MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
403 }
404 std::vector<size_t> device_shape;
405 const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
406 const size_t N1 = (shape[0] + kCubeSize - 1) / kCubeSize;
407 device_shape.push_back(shape[D_ncdhw] * C1 * shape[H_ncdhw] * shape[W_ncdhw]);
408 device_shape.push_back(N1);
409 device_shape.push_back(kCubeSize);
410 device_shape.push_back(kCubeSize);
411 return device_shape;
412 }
413
// NCDHW -> FRACTAL_Z_3D for dynamic shapes: the fused leading dim needs
// C, D, H and W all known; N1 only needs N known.
std::vector<int64_t> Fracz3DDeviceDynamicShape(const std::vector<int64_t> &shape) {
  // NCDHW -> Frac_Z_3D
  if (shape.size() != kNcdhw) {
    MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
  }
  std::vector<int64_t> device_shape;
  auto tmp = SizeToLong(kCubeSize);
  if (HasShapeDynamic({shape[C_ncdhw], shape[D_ncdhw], shape[H_ncdhw], shape[W_ncdhw]})) {
    device_shape.push_back(Shape::SHP_ANY);
  } else {
    const int64_t C1 = (shape[1] + tmp - 1) / tmp;  // channel dim rounded up to cube blocks
    device_shape.push_back(shape[D_ncdhw] * C1 * shape[H_ncdhw] * shape[W_ncdhw]);
  }

  const int64_t N1 = (shape[0] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[0] + tmp - 1) / tmp;
  device_shape.push_back(N1);
  device_shape.push_back(tmp);
  device_shape.push_back(tmp);
  return device_shape;
}
434
C1hwncoc0DeviceShape(const std::vector<size_t> & shape)435 std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
436 if (!CheckDims(shape)) {
437 MS_LOG(EXCEPTION) << "Check dims failed.";
438 }
439 std::vector<size_t> device_shape;
440 device_shape.push_back((shape[kC] - 1) / kCubeSize + 1);
441 device_shape.push_back(shape[kH]);
442 device_shape.push_back(shape[kW]);
443 device_shape.push_back(shape[kN]);
444 device_shape.push_back(kCubeSize);
445 device_shape.push_back(kCubeSize);
446 return device_shape;
447 }
448
C1hwncoc0DeviceDynamicShape(const std::vector<int64_t> & shape)449 std::vector<int64_t> C1hwncoc0DeviceDynamicShape(const std::vector<int64_t> &shape) {
450 if (!CheckDims(shape)) {
451 MS_LOG(EXCEPTION) << "Check dims failed.";
452 }
453 std::vector<int64_t> device_shape;
454 auto tmp = SizeToLong(kCubeSize);
455 shape[kC] == Shape::SHP_ANY ? device_shape.push_back(Shape::SHP_ANY)
456 : device_shape.push_back((shape[kC] - 1) / tmp + 1);
457 device_shape.push_back(shape[kH]);
458 device_shape.push_back(shape[kW]);
459 device_shape.push_back(shape[kN]);
460 device_shape.push_back(tmp);
461 device_shape.push_back(tmp);
462 return device_shape;
463 }
464
FracZc04DeviceShape(const std::vector<size_t> & shape)465 std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
466 if (!CheckDims(shape)) {
467 MS_LOG(EXCEPTION) << "Check dims failed.";
468 }
469 std::vector<size_t> device_shape;
470 const size_t c0 = 4;
471 auto first_dim = DivCeil(c0 * shape[kH] * shape[kW], kCubeSize);
472 auto no = DivCeil(shape.at(kN), kCubeSize);
473 device_shape.push_back(first_dim);
474 device_shape.push_back(no);
475 device_shape.push_back(kCubeSize);
476 device_shape.push_back(kCubeSize);
477 return device_shape;
478 }
479
// NCHW -> FRACTAL_Z_C04 for dynamic shapes: the fused first dim needs H and
// W known; the N-block dim needs N known; trailing dims are cube-sized.
std::vector<int64_t> FracZc04DeviceDynamicShape(const std::vector<int64_t> &shape) {
  if (!CheckDims(shape)) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
  std::vector<int64_t> device_shape;
  const int64_t c0 = 4;  // fixed C0 for the C04 variant

  int64_t first_dim;
  if (HasShapeDynamic({shape[kH], shape[kW]})) {
    first_dim = Shape::SHP_ANY;
  } else {
    first_dim = DivCeil(c0 * shape[kH] * shape[kW], SizeToLong(kCubeSize));
  }
  auto shape_kN = shape.at(kN);
  int64_t no = (shape_kN == Shape::SHP_ANY) ? Shape::SHP_ANY : DivCeil(shape.at(kN), SizeToLong(kCubeSize));
  device_shape.push_back(first_dim);
  device_shape.push_back(no);
  device_shape.push_back(SizeToLong(kCubeSize));
  device_shape.push_back(SizeToLong(kCubeSize));
  return device_shape;
}
501
Nc1hwc04DeviceShape(const std::vector<size_t> & shape)502 std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
503 if (!CheckDims(shape)) {
504 MS_LOG(EXCEPTION) << "Check dims failed.";
505 }
506 std::vector<size_t> device_shape;
507 const size_t C1 = 1;
508 const size_t C0 = 4;
509 device_shape.push_back(shape[kN]);
510 device_shape.push_back(C1);
511 device_shape.push_back(shape[kH]);
512 device_shape.push_back(shape[kW]);
513 device_shape.push_back(C0);
514 return device_shape;
515 }
516
Nc1hwc04DeviceDynamicShape(const std::vector<int64_t> & shape)517 std::vector<int64_t> Nc1hwc04DeviceDynamicShape(const std::vector<int64_t> &shape) {
518 if (!CheckDims(shape)) {
519 MS_LOG(EXCEPTION) << "Check dims failed.";
520 }
521 std::vector<int64_t> device_shape;
522 const int64_t C1 = 1;
523 const int64_t C0 = 4;
524 device_shape.push_back(shape[kN]);
525 device_shape.push_back(C1);
526 device_shape.push_back(shape[kH]);
527 device_shape.push_back(shape[kW]);
528 device_shape.push_back(C0);
529 return device_shape;
530 }
531
NcdhwDeviceShape(const std::vector<size_t> & shape)532 std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) {
533 if (shape.size() < kNcdhw) {
534 MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
535 }
536 return shape;
537 }
538
NcdhwDeviceDynamicShape(const std::vector<int64_t> & shape)539 std::vector<int64_t> NcdhwDeviceDynamicShape(const std::vector<int64_t> &shape) {
540 if (shape.size() < kNcdhw) {
541 MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
542 }
543 return shape;
544 }
545
546 // change channel-first shape to channel-last shape.
547 // eg. [2,3,4] => [2,4,3]; [2,3,4,5] => [2,4,5,3]
// change channel-first shape to channel-last shape.
// eg. [2,3,4] => [2,4,3]; [2,3,4,5] => [2,4,5,3]
std::vector<size_t> ChannelLastDeviceShape(const std::vector<size_t> &shape) {
  auto dim = shape.size();
  // Ranks 0 and 1 have no separate channel axis to move; the original code
  // indexed axis[dim - 1] / shape[1] out of range for them.
  if (dim < 2) {
    return shape;
  }
  // Build the permutation [0, 2, 3, ..., dim-1, 1]: keep N first, move C last.
  std::vector<size_t> axis;
  axis.resize(dim);
  int step_value = 2;
  std::iota(axis.begin() + 1, axis.end(), step_value);
  axis[dim - 1] = 1;

  std::vector<size_t> device_shape;
  (void)std::transform(axis.begin(), axis.end(), std::back_inserter(device_shape),
                       [&shape](size_t n) { return shape[n]; });

  return device_shape;
}
562
563 // change channel-first shape to channel-last shape.
564 // eg. [2,3,4] => [2,4,3]; [2,3,4,5] => [2,4,5,3]
// change channel-first shape to channel-last shape (dynamic-dim variant).
// eg. [2,3,4] => [2,4,3]; [2,3,4,5] => [2,4,5,3]
std::vector<int64_t> ChannelLastDeviceDynamicShape(const std::vector<int64_t> &shape) {
  auto dim = shape.size();
  // Ranks 0 and 1 have no separate channel axis to move; the original code
  // indexed axis[dim - 1] / shape[1] out of range for them.
  if (dim < 2) {
    return shape;
  }
  // Build the permutation [0, 2, 3, ..., dim-1, 1]: keep N first, move C last.
  std::vector<int64_t> axis;
  axis.resize(dim);
  int step_value = 2;
  std::iota(axis.begin() + 1, axis.end(), step_value);
  axis[dim - 1] = 1;

  std::vector<int64_t> device_shape;
  (void)std::transform(axis.begin(), axis.end(), std::back_inserter(device_shape),
                       [&shape](size_t n) { return shape[n]; });

  return device_shape;
}
579
// NCHW weight shape -> FRACTAL_Z layout with group-convolution support.
// Several small groups may be packed into one cube block: e_mult is the
// number of groups packed together, chosen as the smallest multiplier that
// cube-aligns both per-group Cin and Cout, capped at the group count.
std::vector<size_t> FracZDeviceShapeWithGroups(const std::vector<size_t> &shape, const int64_t groups = 1) {
  if (!CheckDims(shape)) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
  if (groups <= 0) {
    MS_LOG(EXCEPTION) << "The value of groups should be greater than 0, but got " << groups;
  }
  size_t group_size = LongToSize(groups);
  size_t cin_ori = shape[kC];             // input channels (per group for conv weights — assumption, verify)
  size_t cout_ori = shape[kN] / group_size;  // output channels per group
  size_t e_mult = std::min(Lcm(Lcm(cin_ori, kCubeSize) / cin_ori, Lcm(cout_ori, kCubeSize) / cout_ori), group_size);
  size_t cin_opt = DivCeil(e_mult * cin_ori, kCubeSize) * kCubeSize;  // packed Cin rounded up to cube size
  size_t c1_dim = cin_opt / kCubeSize;
  size_t g_dim = DivCeil(group_size, e_mult);  // number of packed group blocks
  size_t n1 = DivCeil(cout_ori * e_mult, kCubeSize);
  std::vector<size_t> device_shape;
  device_shape.push_back(g_dim * c1_dim * shape[kH] * shape[kW]);
  device_shape.push_back(n1);
  device_shape.push_back(kNiSize);
  device_shape.push_back(kCubeSize);
  return device_shape;
}
602
// Dynamic-shape overload of FracZDeviceShapeWithGroups: dims derived from an
// unknown (SHP_ANY) C or N are emitted as SHP_ANY; the fused first dim also
// needs H and W known. Packing logic matches the static overload above.
std::vector<int64_t> FracZDeviceShapeWithGroups(const std::vector<int64_t> &shape, const int64_t groups = 1) {
  if (!CheckDims(shape)) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }

  int64_t c1_dim = Shape::SHP_ANY;
  int64_t g_dim = Shape::SHP_ANY;
  int64_t n1 = Shape::SHP_ANY;
  if (groups <= 0) {
    MS_LOG(EXCEPTION) << "The value of groups should be greater than 0, but got " << groups;
  }
  auto tmp = SizeToLong(kCubeSize);
  // Group packing can only be computed when both C and N are known.
  if (!HasShapeDynamic({shape[kC], shape[kN]})) {
    size_t group_size = LongToSize(groups);
    size_t cin_ori_tmp = LongToSize(shape[kC]);
    size_t cout_ori_tmp = LongToSize(shape[kN]) / group_size;
    size_t e_mult =
      std::min(Lcm(Lcm(cin_ori_tmp, kCubeSize) / cin_ori_tmp, Lcm(cout_ori_tmp, kCubeSize) / cout_ori_tmp), group_size);
    int64_t cin_opt = SizeToLong(DivCeil(e_mult * cin_ori_tmp, kCubeSize) * kCubeSize);
    c1_dim = cin_opt / tmp;
    g_dim = SizeToLong(DivCeil(group_size, e_mult));
    n1 = SizeToLong(DivCeil(cout_ori_tmp * e_mult, kCubeSize));
  }

  std::vector<int64_t> device_shape;
  if (!HasShapeDynamic({shape[kC], shape[kN], shape[kH], shape[kW]})) {
    device_shape.push_back(g_dim * c1_dim * shape[kH] * shape[kW]);
  } else {
    device_shape.push_back(Shape::SHP_ANY);
  }
  device_shape.push_back(n1);
  device_shape.push_back(SizeToLong(kNiSize));
  device_shape.push_back(tmp);
  return device_shape;
}
638
// Host shape -> FRACTAL_NZ: the last two dims are tiled into cube blocks,
// emitted as [..., W1, H1, cube, cube]; leading dims are kept as-is.
std::vector<size_t> FracNZDeviceShape(const std::vector<size_t> &shape) {
  if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) {
    // For [1] and [1024] shape we can treat it as NZ shape
    return shape;
  }
  std::vector<size_t> device_shape;
  if (shape.size() < kShape2dDims) {
    MS_LOG(EXCEPTION) << "Format FRACTAL_NZ don't support shape with " << shape.size() << " dims";
  } else {
    // Keep every dim except the trailing two, which are replaced below.
    const auto remove_dim = 2;
    (void)std::copy(shape.begin(), shape.end() - remove_dim, std::back_inserter(device_shape));
  }
  auto h1 = (shape[shape.size() - kDim2] - 1) / kCubeSize + 1;  // ceil(H / cube)
  auto w1 = (shape[shape.size() - kDim1] - 1) / kCubeSize + 1;  // ceil(W / cube)
  device_shape.push_back(w1);
  device_shape.push_back(h1);
  device_shape.push_back(kCubeSize);
  device_shape.push_back(kCubeSize);
  return device_shape;
}
659
// Dynamic-shape variant of FracNZDeviceShape: an unknown H or W yields an
// unknown H1/W1 block count; the trailing cube dims stay fixed.
std::vector<int64_t> FracNZDeviceDynamicShape(const std::vector<int64_t> &shape) {
  std::vector<int64_t> device_shape;
  if (shape.size() == 1 && (shape[0] == 1 || shape[0] % SizeToLong(kCubeSize) == 0)) {
    // For [1] and [1024] shape we can treat it as NZ shape
    return shape;
  }
  if (shape.size() < kShape2dDims) {
    MS_LOG(EXCEPTION) << "Format FRACTAL_NZ don't support shape with " << shape.size() << " dims";
  } else {
    // Keep every dim except the trailing two, which are replaced below.
    (void)std::copy(shape.begin(), shape.end() - kDim2, std::back_inserter(device_shape));
  }
  int64_t h_shape = shape[shape.size() - kDim2];
  int64_t w_shape = shape[shape.size() - kDim1];
  int64_t h1 = (h_shape == Shape::SHP_ANY) ? Shape::SHP_ANY : (h_shape - 1) / SizeToLong(kCubeSize) + 1;
  int64_t w1 = (w_shape == Shape::SHP_ANY) ? Shape::SHP_ANY : (w_shape - 1) / SizeToLong(kCubeSize) + 1;
  device_shape.push_back(w1);
  device_shape.push_back(h1);
  device_shape.push_back(kCubeSize);
  device_shape.push_back(kCubeSize);
  return device_shape;
}
681
// Host shape -> FRACTAL_NZ_LSTM. The weight is split as N = 4*h (the 4 LSTM
// gates — assumption based on the fixed divisor c0 = 4, verify) and
// C = i + h; both i and h are then tiled into cube blocks.
std::vector<size_t> FracNZLSTMDeviceShape(const std::vector<size_t> &shape) {
  const size_t c0 = 4;
  const size_t h = shape.at(kN) / c0;   // hidden size: N divided by the 4 gates
  const size_t i = shape.at(kC) - h;    // input size: remainder of the C dim
  const size_t first = DivCeil(i, kCubeSize) + DivCeil(h, kCubeSize);
  const size_t second = c0 * DivCeil(h, kCubeSize);
  std::vector<size_t> device_shape;
  device_shape.push_back(first);
  device_shape.push_back(second);
  device_shape.push_back(kCubeSize);
  device_shape.push_back(kCubeSize);
  return device_shape;
}
695
// Dynamic-shape variant of FracNZLSTMDeviceShape: the first dim needs both
// N and C known; the second only needs N known. Unknown inputs propagate as
// SHP_ANY.
std::vector<int64_t> FracNZLSTMDeviceDynamicShape(const std::vector<int64_t> &shape) {
  std::vector<int64_t> device_shape;
  const int64_t c0 = 4;
  const int64_t h_shape = shape.at(kN);
  const int64_t i_shape = shape.at(kC);
  // hidden size h = N / 4 (4 LSTM gates — assumption, see static variant)
  const int64_t h = (h_shape == Shape::SHP_ANY) ? Shape::SHP_ANY : h_shape / c0;

  int64_t first = Shape::SHP_ANY;
  if (h_shape != Shape::SHP_ANY && i_shape != Shape::SHP_ANY) {
    int64_t i = i_shape - h;  // input size: remainder of the C dim
    first = DivCeil(i, SizeToLong(kCubeSize)) + DivCeil(h, SizeToLong(kCubeSize));
  }
  const int64_t second = (h == Shape::SHP_ANY) ? Shape::SHP_ANY : c0 * DivCeil(h, SizeToLong(kCubeSize));
  device_shape.push_back(first);
  device_shape.push_back(second);
  device_shape.push_back(kCubeSize);
  device_shape.push_back(kCubeSize);
  return device_shape;
}
715
// Host shape -> FRACTAL_ZN_RNN. The last dim must be n_num * hidden_size
// (one block per RNN gate); the second-last dim must be input_size,
// hidden_size, or their sum. Both are re-expressed in 16-aligned blocks and
// the [16, C0] cube dims are appended.
std::vector<size_t> FracZNRNNDeviceShape(const std::vector<size_t> &shape,
                                         const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16}) {
  if (shape.size() < kShape2dDims) {
    MS_LOG(EXCEPTION) << "Format FRACTAL_ZN_RNN don't support shape with " << shape.size() << " dims";
  }
  size_t input_size = LongToSize(input_hidden_size[0]);
  size_t hidden_size = LongToSize(input_hidden_size[1]);
  auto dim_last1 = shape[shape.size() - kDim1];
  auto dim_last2 = shape[shape.size() - kDim2];
  if (hidden_size == 0) {
    MS_LOG(EXCEPTION) << "Hidden_size should not be 0.";
  }
  // cppcheck-suppress *
  if (dim_last1 % hidden_size != 0) {
    MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hidden_size;
  }
  size_t n_num = dim_last1 / hidden_size;  // number of hidden-size blocks (gates)
  const size_t NUM16 = 16;
  const size_t C0 = kCubeSize;

  std::vector<size_t> device_shape = shape;
  if (dim_last2 == input_size || dim_last2 == hidden_size) {
    device_shape[shape.size() - kDim2] = DivCeil(dim_last2, NUM16);
  } else if (dim_last2 == input_size + hidden_size) {
    // concatenated [input; hidden] weight: align each part separately
    device_shape[shape.size() - kDim2] = DivCeil(input_size, NUM16) + DivCeil(hidden_size, NUM16);
  } else {
    MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid.";
  }
  device_shape[shape.size() - 1] = n_num * DivCeil(hidden_size, C0);
  device_shape.push_back(NUM16);
  device_shape.push_back(C0);
  return device_shape;
}
749
FracZNRNNDeviceDynamicShape(const std::vector<int64_t> & shape,const std::vector<int64_t> & input_hidden_size={kAlign16, kAlign16})750 std::vector<int64_t> FracZNRNNDeviceDynamicShape(const std::vector<int64_t> &shape,
751 const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16}) {
752 if (shape.size() < kShape2dDims) {
753 MS_LOG(EXCEPTION) << "Format FRACTAL_NZ_RNN don't support shape with " << shape.size() << " dims";
754 }
755 int64_t input_size = input_hidden_size[0];
756 int64_t hidden_size = input_hidden_size[1];
757 auto dim_last1 = shape[shape.size() - 1];
758 auto dim_last2 = shape[shape.size() - 2];
759 const int64_t NUM16 = 16;
760 const int64_t C0 = SizeToLong(kCubeSize);
761
762 std::vector<int64_t> device_shape = shape;
763 if (dim_last2 == Shape::SHP_ANY) {
764 device_shape[shape.size() - kDim2] = Shape::SHP_ANY;
765 } else if (dim_last2 == input_size || dim_last2 == hidden_size) {
766 device_shape[shape.size() - kDim2] = DivCeil(dim_last2, NUM16);
767 } else if (dim_last2 == input_size + hidden_size) {
768 device_shape[shape.size() - kDim2] = DivCeil(input_size, NUM16) + DivCeil(hidden_size, NUM16);
769 } else {
770 MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid.";
771 }
772 if (dim_last1 == Shape::SHP_ANY) {
773 device_shape[shape.size() - kDim1] = Shape::SHP_ANY;
774 } else {
775 if (dim_last1 % hidden_size != 0) {
776 MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hidden_size;
777 }
778 int64_t n_num = shape[shape.size() - 1] / hidden_size;
779 device_shape[shape.size() - kDim1] = n_num * DivCeil(hidden_size, C0);
780 }
781 device_shape.push_back(NUM16);
782 device_shape.push_back(C0);
783 return device_shape;
784 }
785
NDRNNBiasDeviceShape(const std::vector<size_t> & shape,const int64_t hidden_size=16)786 std::vector<size_t> NDRNNBiasDeviceShape(const std::vector<size_t> &shape, const int64_t hidden_size = 16) {
787 if (shape.empty()) {
788 MS_LOG(EXCEPTION) << "Format ND_RNN_BIAS don't support empty shape.";
789 }
790 if (hidden_size <= 0) {
791 MS_LOG(EXCEPTION) << "Hidden_size should be greater than 0, but got " << hidden_size;
792 }
793 size_t hid_size = LongToSize(hidden_size);
794 // cppcheck-suppress *
795 if (shape[shape.size() - 1] % hid_size != 0) {
796 MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hid_size;
797 }
798 size_t n_num = shape[shape.size() - 1] / hid_size;
799 const size_t C0 = kCubeSize;
800 std::vector<size_t> device_shape = shape;
801 device_shape[shape.size() - 1] = n_num * DivCeil(hid_size, C0) * C0;
802 return device_shape;
803 }
804
// Compute the ND_RNN_BIAS device shape from a host shape that may contain a dynamic
// last dim (Shape::SHP_ANY); an unknown last dim stays unknown. Otherwise the last
// dim packs n_num chunks of hidden_size, each padded up to a whole cube (C0).
std::vector<int64_t> NDRNNBiasDeviceDynamicShape(const std::vector<int64_t> &shape, const int64_t hidden_size = 16) {
  if (shape.empty()) {
    MS_LOG(EXCEPTION) << "Format ND_RNN_BIAS don't support empty shape.";
  }
  const int64_t C0 = SizeToLong(kCubeSize);
  std::vector<int64_t> device_shape = shape;
  // cppcheck-suppress *
  auto dim_last1 = shape[shape.size() - 1];
  if (dim_last1 == Shape::SHP_ANY) {
    // Unknown host dim -> unknown device dim; no validation possible.
    device_shape[shape.size() - 1] = Shape::SHP_ANY;
  } else {
    // hidden_size is only validated on the static path, where it is actually used.
    if (hidden_size <= 0) {
      MS_LOG(EXCEPTION) << "Hidden_size should be greater than 0, but got " << hidden_size;
    }
    // cppcheck-suppress *
    if (dim_last1 % hidden_size != 0) {
      MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hidden_size;
    }
    int64_t n_num = shape[shape.size() - 1] / hidden_size;
    device_shape[shape.size() - 1] = n_num * DivCeil(hidden_size, C0) * C0;
  }
  return device_shape;
}
828 } // namespace
829
// Return the FracZ "groups" attribute for the given node / output index, defaulting
// to 1 when the node is null, has no such attribute, or the index is not covered.
int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index) {
  if (node == nullptr) {
    return 1;
  }
  if (node->isa<CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    if (AnfAlgo::HasNodeAttr(kAttrFracZGroup, cnode)) {
      auto node_name = AnfAlgo::GetCNodeName(cnode);
      // AllReduce/Broadcast carry a per-output index list; other ops apply the
      // group attribute to every output.
      if (node_name == kAllReduceOpName || node_name == kBroadcastOpName) {
        // if index not exists in fracz_group_idx, return default value 1
        // NOTE(review): assumes kAttrFracZGroupIdx is always set whenever
        // kAttrFracZGroup is set on these ops — confirm with the pass that adds it.
        auto fz_group_idx = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, kAttrFracZGroupIdx);
        int64_t out_index = SizeToLong(index);
        auto fz_iter = std::find(std::begin(fz_group_idx), std::end(fz_group_idx), out_index);
        if (fz_iter == std::end(fz_group_idx)) {
          return 1;
        }
      }
      return AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFracZGroup);
    }
  } else if (node->isa<Parameter>()) {
    // Parameters store the group count directly on the Parameter object.
    auto param = node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    return param->fracz_group();
  }
  return 1;
}
856
GetAttrInputAndHiddenSize(const AnfNodePtr & node)857 std::vector<int64_t> GetAttrInputAndHiddenSize(const AnfNodePtr &node) {
858 MS_EXCEPTION_IF_NULL(node);
859 std::vector<int64_t> input_hidden_size = {kAlign16, kAlign16};
860 if (!node->isa<CNode>()) {
861 return input_hidden_size;
862 }
863 auto cnode = node->cast<CNodePtr>();
864 MS_EXCEPTION_IF_NULL(cnode);
865 if (!AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode) || !AnfAlgo::HasNodeAttr(kAttrInputSize, cnode)) {
866 MS_LOG(EXCEPTION)
867 << "Node with format FRACTAL_ZN_RNN or ND_RNN_BIAS should have hidden_size or input_size attr. Node info:"
868 << cnode->DebugString();
869 }
870 input_hidden_size[0] = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrInputSize);
871 input_hidden_size[1] = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrHiddenSize);
872 return input_hidden_size;
873 }
874
IsNeedPadding(const std::string & format,const size_t shape_size)875 bool IsNeedPadding(const std::string &format, const size_t shape_size) {
876 if (shape_size == 0) {
877 return false;
878 }
879 if (format == kOpFormat_DEFAULT || format == kOpFormat_NCHW ||
880 kNoPaddingFormatSet.find(format) != kNoPaddingFormatSet.end()) {
881 return false;
882 } else if (shape_size < kNchwDims) {
883 return true;
884 }
885 return false;
886 }
887
GetRuntimePaddingShape(const AnfNodePtr & node,size_t index)888 ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
889 MS_EXCEPTION_IF_NULL(node);
890 ShapeVector shape;
891 std::vector<size_t> host_shape;
892 if (node->isa<ValueNode>()) {
893 auto value_node = node->cast<ValueNodePtr>();
894 MS_EXCEPTION_IF_NULL(value_node);
895 auto node_value = value_node->value();
896 MS_EXCEPTION_IF_NULL(node_value);
897 auto tensor = node_value->cast<tensor::TensorPtr>();
898 if (tensor == nullptr) {
899 MS_LOG(EXCEPTION) << " The node[ " << node->DebugString() << "]'s cannot convert ";
900 }
901 auto shape_temp = tensor->shape();
902 (void)std::transform(shape_temp.begin(), shape_temp.end(), std::back_inserter(host_shape), LongToSize);
903 if (host_shape.empty()) {
904 host_shape.push_back(1);
905 }
906 } else {
907 host_shape = AnfAlgo::GetOutputInferShape(node, index);
908 }
909 auto format = AnfAlgo::GetOutputFormat(node, index);
910 if (trans::IsNeedPadding(format, host_shape.size())) {
911 host_shape = trans::PaddingShape(host_shape, format, AnfAlgo::GetOutputReshapeType(node, index));
912 }
913 std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToLong);
914 return shape;
915 }
916
StringToAxisVector4D(const std::string & reshape_type_str,std::vector<Axis> * reshape_type_vec)917 void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
918 MS_EXCEPTION_IF_NULL(reshape_type_vec);
919 if (reshape_type_str.empty()) {
920 MS_LOG(DEBUG) << "Reshape type str is empty, no need padding.";
921 return;
922 }
923 for (const auto &c : reshape_type_str) {
924 switch (c) {
925 case 'N':
926 reshape_type_vec->push_back(N);
927 break;
928 case 'C':
929 reshape_type_vec->push_back(C);
930 break;
931 case 'H':
932 reshape_type_vec->push_back(H);
933 break;
934 case 'W':
935 reshape_type_vec->push_back(W);
936 break;
937 default:
938 MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type.";
939 }
940 }
941 }
StringToAxisVector5D(const std::string & reshape_type_str,std::vector<Axis5D> * reshape_type_vec)942 void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec) {
943 MS_EXCEPTION_IF_NULL(reshape_type_vec);
944 if (reshape_type_str.empty()) {
945 MS_LOG(DEBUG) << "Reshape type str is empty, no need padding.";
946 return;
947 }
948 for (const auto &c : reshape_type_str) {
949 switch (c) {
950 case 'N':
951 reshape_type_vec->push_back(N_ncdhw);
952 break;
953 case 'C':
954 reshape_type_vec->push_back(C_ncdhw);
955 break;
956 case 'D':
957 reshape_type_vec->push_back(D_ncdhw);
958 break;
959 case 'H':
960 reshape_type_vec->push_back(H_ncdhw);
961 break;
962 case 'W':
963 reshape_type_vec->push_back(W_ncdhw);
964 break;
965 default:
966 MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type.";
967 }
968 }
969 }
970
// Convert a static host shape to the device shape for `format`.
// groups affects only FRAC_Z; input_hidden_size ({input, hidden}) affects only the
// RNN formats. Shapes are padded to 4-D/5-D defaults first when the format needs it.
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
                                       const int64_t groups, const std::vector<int64_t> &input_hidden_size) {
  using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
  // Dispatch table from device format to its shape-inference helper.
  const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
                                                                    {kOpFormat_NHWC, NhwcDeviceShape},
                                                                    {kOpFormat_HWCN, HwchDeviceShape},
                                                                    {kOpFormat_FRAC_Z, FracZDeviceShape},
                                                                    {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape},
                                                                    {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
                                                                    {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
                                                                    {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape},
                                                                    {kOpFormat_NCDHW, NcdhwDeviceShape},
                                                                    {kOpFormat_ChannelLast, ChannelLastDeviceShape},
                                                                    {kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape},
                                                                    {kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape},
                                                                    {kOpFormat_FRAC_NZ, FracNZDeviceShape},
                                                                    {kOpFormat_FRACTAL_ZN_LSTM, FracNZLSTMDeviceShape}};

  // ND/default: host and device layouts coincide.
  if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
    return shape;
  }
  // Grouped FRAC_Z has its own helper that tiles per group.
  if (groups > 1 && format == kOpFormat_FRAC_Z) {
    return FracZDeviceShapeWithGroups(shape, groups);
  }
  if (format == kOpFormat_FRACTAL_ZN_RNN) {
    return FracZNRNNDeviceShape(shape, input_hidden_size);
  }
  if (format == kOpFormat_ND_RNN_BIAS) {
    return NDRNNBiasDeviceShape(shape, input_hidden_size[1]);
  }
  auto temp_shape = shape;
  // 4-D formats: pad any non-4-D shape to NCHW defaults before inference.
  if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM &&
      shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
    MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
    temp_shape = PaddingShapeTo4dDefault(shape);
  }
  // 5-D (3-D conv) formats: pad to NCDHW defaults instead.
  if (shape.size() != kNcdhw && k3DFormatSet.find(format) != k3DFormatSet.end()) {
    temp_shape = PaddingShapeTo5dDefault(shape);
  }
  auto iter = device_shape_map.find(format);
  if (iter == device_shape_map.end()) {
    MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
  }
  return iter->second(temp_shape);
}
1016
// Dynamic-shape overload of TransShapeToDevice: same dispatch, but every helper
// propagates Shape::SHP_ANY dims instead of computing from them.
// Note: unlike the static overload, NCHW is returned as-is here.
std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const std::string &format,
                                        const int64_t groups, const std::vector<int64_t> &input_hidden_size) {
  using DeviceShapeTransfer = std::function<std::vector<int64_t>(const std::vector<int64_t> &)>;
  // Dispatch table from device format to its dynamic shape-inference helper.
  const std::map<std::string, DeviceShapeTransfer> device_shape_map{
    {kOpFormat_NCHW, NchwDeviceDynamicShape},
    {kOpFormat_NHWC, NhwcDeviceDynamicShape},
    {kOpFormat_HWCN, HwchDeviceDynamicShape},
    {kOpFormat_FRAC_Z, FracZDeviceDynamicShape},
    {kOpFormat_NC1HWC0, Nc1hwc0DeviceDynamicShape},
    {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceDynamicShape},
    {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceDynamicShape},
    {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceDynamicShape},
    {kOpFormat_NCDHW, NcdhwDeviceDynamicShape},
    {kOpFormat_ChannelLast, ChannelLastDeviceDynamicShape},
    {kOpFormat_NDC1HWC0, Ndc1hwc0DeviceDynamicShape},
    {kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceDynamicShape},
    {kOpFormat_FRAC_NZ, FracNZDeviceDynamicShape},
    {kOpFormat_FRACTAL_ZN_LSTM, FracNZLSTMDeviceDynamicShape}};

  if (format == kOpFormat_ND || format == kOpFormat_DEFAULT || format == kOpFormat_NCHW) {
    return shape;
  }
  // Grouped FRAC_Z has its own helper that tiles per group.
  if (groups > 1 && format == kOpFormat_FRAC_Z) {
    return FracZDeviceShapeWithGroups(shape, groups);
  }
  if (format == kOpFormat_FRACTAL_ZN_RNN) {
    return FracZNRNNDeviceDynamicShape(shape, input_hidden_size);
  }
  if (format == kOpFormat_ND_RNN_BIAS) {
    return NDRNNBiasDeviceDynamicShape(shape, input_hidden_size[1]);
  }
  auto temp_shape = shape;
  // 4-D formats: pad any non-4-D shape to NCHW defaults before inference.
  if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM &&
      shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
    MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
    temp_shape = PaddingShapeTo4dDefault(shape);
  }
  // 5-D (3-D conv) formats: pad to NCDHW defaults instead.
  if (shape.size() != kNcdhw && k3DFormatSet.find(format) != k3DFormatSet.end()) {
    temp_shape = PaddingShapeTo5dDefault(shape);
  }
  auto iter = device_shape_map.find(format);
  if (iter == device_shape_map.end()) {
    MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
  }
  return iter->second(temp_shape);
}
1063
// Validate common FormatArgs preconditions for 4-D host<->device transforms.
// On success fills *size (bytes per element) and *total_size (device buffer bytes)
// and returns true; on any mismatch logs an error and returns false.
bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
  if (args.host_shape.size() != kNchwDims) {
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
  MS_EXCEPTION_IF_NULL(size);
  MS_EXCEPTION_IF_NULL(total_size);
  *size = abstract::TypeIdSize(args.src_data_type);
  // TypeIdSize returns 0 for unknown dtypes.
  if (*size < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  // Device buffer must hold exactly shape-product * element-size bytes.
  *total_size = abstract::ShapeSize(args.device_shape) * (*size);
  if (*total_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << *total_size << ", device_size:" << args.device_size;
    return false;
  }
  return true;
}
1083
// Cast host data from args.host_data_type to args.device_data_type into `result`.
// Returns false (with an error log) when the type pair is unsupported by mode_map
// or the underlying cast kernel fails.
bool TransDataType(const TypeIdArgs &args, void *result) {
  MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to "
                << TypeIdLabel(args.device_data_type);
  MS_EXCEPTION_IF_NULL(result);
  // mode_map keys are (src, dst) type pairs; the value selects the cast kernel mode.
  std::pair<TypeId, TypeId> type_info(args.host_data_type, args.device_data_type);
  auto iter = mode_map.find(type_info);
  if (iter == mode_map.end()) {
    MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.host_data_type)
                  << ", dst_type:" << TypeIdLabel(args.device_data_type);
    return false;
  }
  auto trans_mode = iter->second;
  if (!CastKernel(args, result, args.host_shape_size, trans_mode)) {
    MS_LOG(ERROR) << "Failed to trans datatype..";
    return false;
  }
  return true;
}
1102
// Convert host-layout data to args.device_format, writing into `result`.
// groups is honored only for grouped FRAC_Z; other formats dispatch through
// kTransFormatMapOfHostToDevice.
bool TransFormat(const FormatArgs &args, void *result, int64_t groups) {
  MS_LOG(DEBUG) << "Start trans format.";
  if (abstract::TypeIdSize(args.src_data_type) < 1) {
    MS_LOG(ERROR) << "Invalid datatype..";
    return false;
  }
  // HWCN/NHWC are simple permutations of NCHW handled by one helper.
  if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
    return NchwTo4D(args, result);
  }
  if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) {
    return NchwToFracZWithGroups(args, result, groups);
  }
  auto iter = kTransFormatMapOfHostToDevice.find(args.device_format);
  if (iter == kTransFormatMapOfHostToDevice.end()) {
    MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
  }
  return iter->second(args, result);
}
1121
TransFormat(const FormatArgs & args,void * result,const AnfNodePtr & node,const size_t index)1122 bool TransFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index) {
1123 int64_t groups = 1;
1124 if (args.device_format == kOpFormat_FRAC_Z) {
1125 groups = GetAttrGroups(node, index);
1126 }
1127 return TransFormat(args, result, groups);
1128 }
1129
// Convert device-layout data (args.device_format) back to host layout in `result`.
// groups is honored only for grouped FRAC_Z; other formats dispatch through the
// local device-to-host table.
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, int64_t groups) {
  // Dispatch table from device format to its device->host transform.
  const std::map<std::string, FormatTransfer> format_trans_map{
    {kOpFormat_FRAC_Z, FracZToNchw},         {kOpFormat_FRAC_NZ, FracNzToNchw},
    {kOpFormat_NC1HWC0, Nc1hwc0ToNchw},      {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
    {kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}, {kOpFormat_NDC1HWC0, Ndc1hwc0ToNcdhw},
    {kOpFormat_FRACTAL_Z_3D, FracZ3DToNcdhw}};

  MS_LOG(DEBUG) << "Start trans format.";
  if (abstract::TypeIdSize(args.src_data_type) < 1) {
    MS_LOG(ERROR) << "Invalid datatype..";
    return false;
  }
  // HWCN/NHWC are simple permutations handled by one helper.
  if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
    return ToNchw(args, result);
  }
  if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) {
    return FracZToNchwWithGroups(args, result, groups);
  }
  auto iter = format_trans_map.find(args.device_format);
  if (iter == format_trans_map.end()) {
    MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
  }
  return iter->second(args, result);
}
1154
TransFormatFromDeviceToHost(const FormatArgs & args,void * result,const AnfNodePtr & node,const size_t index)1155 bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index) {
1156 int64_t groups = 1;
1157 if (args.device_format == kOpFormat_FRAC_Z) {
1158 groups = GetAttrGroups(node, index);
1159 }
1160 return TransFormatFromDeviceToHost(args, result, groups);
1161 }
1162
// Permute NCHW host data into NHWC or HWCN order (args.device_format selects which),
// element by element, into `result`.
bool NchwTo4D(const FormatArgs &args, void *result) {
  // trans nchw to 4d
  MS_LOG(DEBUG) << "Trans format from nchw to 4d.";
  MS_EXCEPTION_IF_NULL(result);
  size_t size = 0;
  size_t total_size = 0;
  if (!CheckArgs(args, &size, &total_size)) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  for (size_t ni = 0; ni < n; ni++) {
    for (size_t ci = 0; ci < c; ci++) {
      for (size_t hi = 0; hi < h; hi++) {
        for (size_t wi = 0; wi < w; wi++) {
          auto src_idx = ni * c * h * w + ci * h * w + hi * w + wi;
          // Destination index is the same element's offset in the target layout.
          // Note: dst_idx stays 0 for any other format, silently overwriting
          // element 0 — unreachable in practice since TransFormat only routes
          // NHWC/HWCN here.
          size_t dst_idx = 0;
          if (args.device_format == kOpFormat_NHWC) {
            dst_idx = ni * h * w * c + hi * w * c + wi * c + ci;
          } else if (args.device_format == kOpFormat_HWCN) {
            dst_idx = hi * w * c * n + wi * c * n + ci * n + ni;
          }
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
    }
  }
  return true;
}
1195
// Inverse of NchwTo4D: gather NHWC- or HWCN-ordered device data back into NCHW
// order in `result`.
bool ToNchw(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format to nchw from 4d.";
  MS_EXCEPTION_IF_NULL(result);
  size_t size = 0;
  size_t total_size = 0;
  if (!CheckArgs(args, &size, &total_size)) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  for (size_t ni = 0; ni < n; ni++) {
    for (size_t ci = 0; ci < c; ci++) {
      for (size_t hi = 0; hi < h; hi++) {
        for (size_t wi = 0; wi < w; wi++) {
          auto dst_idx = ni * c * h * w + ci * h * w + hi * w + wi;
          // Source index is the element's offset in the device layout; see
          // NchwTo4D for the mirror-image mapping.
          size_t src_idx = 0;
          if (args.device_format == kOpFormat_NHWC) {
            src_idx = ni * h * w * c + hi * w * c + wi * c + ci;
          } else if (args.device_format == kOpFormat_HWCN) {
            src_idx = hi * w * c * n + wi * c * n + ci * n + ni;
          }
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
    }
  }
  return true;
}
1227
// Tile NCHW host data into FRAC_Z device layout: a grid of C0 x kCubeSize fractal
// matrices, ceil(C/c0)*H*W down and ceil(N/16) across, zero-padding where N or C
// is not a multiple of the tile size.
bool NchwToFracZ(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from nchw to frac_z";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNchwDims) {
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
  auto size = abstract::TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  const size_t c0 = 16;
  auto c1 = DivCeil(c, c0);
  auto hw = h * w;
  auto chw = c * hw;
  auto hwc0 = hw * c0;
  auto nchw = n * chw;

  // Fractal grid: hf_cnt tiles along N, vf_cnt tiles along C1*H*W; each tile holds
  // c0 x kCubeSize elements.
  auto hf_cnt = DivCeil(n, kCubeSize);
  auto vf_cnt = c1 * hw;
  auto fractal_ele_cnt = c0 * kCubeSize;
  auto total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
  auto dst_size = total_ele_cnt * size;
  if (dst_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size."
                  << "dst size is :" << dst_size << "device size is :" << args.device_size;
    return false;
  }

  for (size_t vfi = 0; vfi < vf_cnt; vfi++) {
    auto vf_base_i = vfi * hf_cnt;  // vertical fractal matrix base index
    for (size_t hfi = 0; hfi < hf_cnt; hfi++) {
      auto gfi = vf_base_i + hfi;  // global fractal matrix index
      auto src_n_offset = hfi * chw * kCubeSize;
      // vfi % hw selects the spatial position; vfi / hw selects the c1 slab.
      auto src_f_offset = src_n_offset + vfi % hw + vfi / hw * hwc0;
      for (size_t row = 0; row < c0; row++) {
        auto src_ci = vfi / hw * c0 + row;
        auto src_row_offset = src_f_offset + row * hw;
        for (size_t col = 0; col < kCubeSize; col++) {
          auto src_ni = hfi * kCubeSize + col;
          auto src_idx = src_row_offset + chw * col;
          auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row;
          // Tiles past the real N/C extents are written as zeros.
          auto pad_zero = src_ni >= n || src_idx >= nchw || src_ci >= c;
          SetData(size, pad_zero, src_idx, dst_idx, args, result);
        }
      }
    }
  }
  return true;
}
1283
// Gather FRAC_Z device data (device_shape = {C1*H*W, N1, Ni, C0}) back into NCHW
// host layout; padded tile entries are simply never read.
bool FracZToNchw(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from frac_z to nchw";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNchwDims) {
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
  auto size = abstract::TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }

  // Device dims: [1] = N1 (cube blocks along N), [2] = Ni (rows per block), [3] = C0.
  auto n0 = args.device_shape.at(1);
  auto ni = args.device_shape.at(2);
  auto c0 = args.device_shape.at(3);
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  auto nc = ni * n0;
  auto ncc0 = nc * c0;
  auto wncc0 = w * ncc0;
  auto hwncc0 = h * wncc0;
  auto hw = h * w;
  auto chw = c * hw;

  for (size_t n_idx = 0; n_idx < n; n_idx++) {
    size_t n_head_addr = n_idx * chw;
    for (size_t c_idx = 0; c_idx < c; c_idx++) {
      size_t c_head_addr = n_head_addr + c_idx * hw;
      for (size_t h_idx = 0; h_idx < h; h_idx++) {
        size_t h_head_addr = c_head_addr + h_idx * w;
        for (size_t w_idx = 0; w_idx < w; w_idx++) {
          size_t dst_idx = h_head_addr + w_idx;
          // Split the channel into (c1, c0) tile coordinates to locate the source.
          size_t c1_idx = c_idx / c0;
          size_t c0_idx = c_idx % c0;
          size_t nc_idx = n_idx;
          size_t src_idx = c1_idx * hwncc0 + h_idx * wncc0 + w_idx * ncc0 + nc_idx * c0 + c0_idx;
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
    }
  }
  return true;
}
1335
// Tile NCHW host data into FRACTAL_Z_C04 layout: channels are grouped in fours
// (c0 = 4), the C1*H*W*C0 axis is cut into kCubeSize rows, and everything outside
// the real N/C extents is zero-padded. The destination is written sequentially.
bool NchwToFracZc04(const FormatArgs &args, void *result) {
  // trans nchw to FracZc04
  MS_LOG(DEBUG) << "Trans format from nchw to FracZc04.";
  MS_EXCEPTION_IF_NULL(result);
  size_t size = 0;
  size_t total_size = 0;
  if (!CheckArgs(args, &size, &total_size)) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
  auto cube = kCubeSize;
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  const size_t c0 = 4;
  auto c1 = DivCeil(c, c0);
  auto hwc0 = h * w * c0;
  auto hwc = h * w * c;
  auto nhwc = n * h * w * c;
  // Cube-tile counts along the N axis and the flattened C1*H*W*C0 axis.
  auto n_cnt = DivCeil(n, cube);
  auto v_cnt = DivCeil(h * w * c0 * c1, cube);
  size_t dst_idx = 0;

  for (size_t vi = 0; vi < v_cnt; vi++) {
    for (size_t ni = 0; ni < n_cnt; ni++) {
      for (size_t col = 0; col < cube; col++) {
        for (size_t row = 0; row < cube; row++) {
          size_t cur_cube_n = cube * ni + col;
          size_t cur_cube_c1hwc0 = cube * vi + row;
          // Decompose the flattened coordinates back into (g, n) and (c1, h, w, c0).
          auto desc_g = cur_cube_n / n;
          auto desc_n = cur_cube_n % n;
          auto desc_c1 = cur_cube_c1hwc0 / hwc0;
          auto desc_c0 = cur_cube_c1hwc0 % c0;
          auto desc_h = (cur_cube_c1hwc0 - hwc0 * desc_c1) / (w * c0);
          auto desc_w = (cur_cube_c1hwc0 - hwc0 * desc_c1 - w * c0 * desc_h) / c0;
          auto c_idx = desc_c1 * c0 + desc_c0;
          auto src_idx = desc_g * nhwc + desc_n * hwc + c_idx * h * w + desc_h * w + desc_w;
          // desc_g >= 1 means the tile ran past the real N extent.
          auto pad_zero = desc_g >= 1 || desc_n >= n || c_idx >= c;
          SetData(size, pad_zero, src_idx, dst_idx, args, result);
          dst_idx++;
        }
      }
    }
  }
  return true;
}
1383
NchwToNc1hwc04(const FormatArgs & args,void * result)1384 bool NchwToNc1hwc04(const FormatArgs &args, void *result) {
1385 MS_LOG(DEBUG) << "Trans format from nchw to Nc1hwc04.";
1386 return NchwToNc1hwc0(args, result);
1387 }
1388
Nc1hwc04ToNchw(const FormatArgs & args,void * result)1389 bool Nc1hwc04ToNchw(const FormatArgs &args, void *result) {
1390 MS_LOG(DEBUG) << "Trans format from Nc1hwc04 to nchw.";
1391 return Nc1hwc0ToNchw(args, result);
1392 }
1393
TransShapeToNz(const std::vector<size_t> & host_shape,std::vector<size_t> * hw_shape)1394 bool TransShapeToNz(const std::vector<size_t> &host_shape, std::vector<size_t> *hw_shape) {
1395 MS_EXCEPTION_IF_NULL(hw_shape);
1396 if (host_shape.empty()) {
1397 MS_LOG(ERROR) << "Size of vector is 0.";
1398 return false;
1399 }
1400 switch (host_shape.size()) {
1401 case 1:
1402 hw_shape->push_back(1);
1403 hw_shape->push_back(1);
1404 hw_shape->push_back(host_shape[0]);
1405 return true;
1406 default:
1407 auto size = host_shape.size();
1408 if (size < kDim2) {
1409 MS_LOG(ERROR) << "Illegal size.";
1410 return false;
1411 }
1412 size_t times = 1;
1413 for (size_t i = 0; i != size - kDim2; i++) {
1414 times *= host_shape[i];
1415 }
1416 hw_shape->push_back(times);
1417 hw_shape->push_back(host_shape[size - kDim2]);
1418 hw_shape->push_back(host_shape[size - kDim1]);
1419 return true;
1420 }
1421 }
1422
// Tile host data into FRAC_NZ layout (device_shape tail = {W1, H1, H0, W0}).
// The host shape is first normalized to {times, H, W}; full W0-wide column blocks
// are copied, then a remainder loop handles W not divisible by W0.
bool NchwToFracNz(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from nchw to frac_nz.";
  MS_EXCEPTION_IF_NULL(result);
  std::vector<size_t> hw_shape;
  if (!TransShapeToNz(args.host_shape, &hw_shape)) {
    MS_LOG(ERROR) << "Trans shape failed..";
    return false;
  }
  if (hw_shape.size() < kDim3 || args.device_shape.size() < kDim4) {
    MS_LOG(ERROR) << "Invalid shape size.";
    return false;
  }
  auto size = abstract::TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype";
    return false;
  }

  auto dst_size = abstract::ShapeSize(args.device_shape) * size;
  if (dst_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
    return false;
  }
  auto times = hw_shape.at(0);  // collapsed batch count
  auto h = hw_shape.at(1);
  auto w = hw_shape.at(2);
  auto hw = h * w;

  auto shape_size = args.device_shape.size();
  auto w1 = args.device_shape[shape_size - 4];
  auto h1 = args.device_shape[shape_size - 3];
  auto h0 = args.device_shape[shape_size - 2];
  auto w0 = args.device_shape[shape_size - 1];
  auto h1h0w0 = h1 * h0 * w0;
  auto w1h1h0w0 = w1 * h1h0w0;
  auto num_w1 = w / w0;  // number of complete W0-wide blocks

  for (size_t times_idx = 0; times_idx < times; times_idx++) {
    auto times_head = times_idx * w1h1h0w0;
    auto src_times_head = times_idx * hw;
    for (size_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
      auto h1h0_head = times_head + h1h0_idx * w0;
      auto src_h_head = src_times_head + h1h0_idx * w;
      // Copy the complete W0-wide column blocks of this row.
      for (size_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
        for (size_t i = 0; i < w0; ++i) {
          size_t src_idx = src_h_head + w1_idx * w0 + i;
          size_t dst_idx = h1h0_head + w1_idx * h1h0w0 + i;
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
      // Remainder columns when w is not a multiple of w0.
      auto w1_head = num_w1 * w0;
      for (size_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
        auto src_w_idx = w1_head + w0_idx;
        size_t dst_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
        size_t src_idx = src_h_head + src_w_idx;
        SetData(size, false, src_idx, dst_idx, args, result);
      }
    }
  }
  return true;
}
1484
// Inverse of NchwToFracNz: gather FRAC_NZ device data (tail = {W1, H1, H0, W0})
// back into the row-major host layout; src/dst index roles are exactly swapped
// relative to the forward transform.
bool FracNzToNchw(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from frac_nz to nchw";
  MS_EXCEPTION_IF_NULL(result);
  std::vector<size_t> hw_shape;
  if (!TransShapeToNz(args.host_shape, &hw_shape)) {
    MS_LOG(ERROR) << "Trans shape failed..";
    return false;
  }
  if (hw_shape.size() < kDim3 || args.device_shape.size() < kDim4) {
    MS_LOG(ERROR) << "Invalid shape size.";
    return false;
  }
  auto size = abstract::TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype";
    return false;
  }

  auto dst_size = abstract::ShapeSize(args.device_shape) * size;
  if (dst_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
    return false;
  }
  auto times = hw_shape.at(0);  // collapsed batch count
  auto h = hw_shape.at(1);
  auto w = hw_shape.at(2);
  auto hw = h * w;

  auto shape_size = args.device_shape.size();
  auto w1 = args.device_shape[shape_size - 4];
  auto h1 = args.device_shape[shape_size - 3];
  auto h0 = args.device_shape[shape_size - 2];
  auto w0 = args.device_shape[shape_size - 1];
  auto h1h0w0 = h1 * h0 * w0;
  auto w1h1h0w0 = w1 * h1h0w0;
  auto num_w1 = w / w0;  // number of complete W0-wide blocks

  for (size_t times_idx = 0; times_idx < times; times_idx++) {
    auto times_head = times_idx * w1h1h0w0;
    auto src_times_head = times_idx * hw;
    for (size_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
      auto h1h0_head = times_head + h1h0_idx * w0;
      auto src_h_head = src_times_head + h1h0_idx * w;
      // Gather the complete W0-wide column blocks of this row.
      for (size_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
        for (size_t i = 0; i < w0; ++i) {
          size_t src_idx = h1h0_head + w1_idx * h1h0w0 + i;
          size_t dst_idx = src_h_head + w1_idx * w0 + i;
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
      // Remainder columns when w is not a multiple of w0.
      auto w1_head = num_w1 * w0;
      for (size_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
        auto src_w_idx = w1_head + w0_idx;
        size_t src_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
        size_t dst_idx = src_h_head + src_w_idx;
        SetData(size, false, src_idx, dst_idx, args, result);
      }
    }
  }
  return true;
}
1546
NchwToNc1hwc0(const FormatArgs & args,void * result)1547 bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
1548 MS_LOG(DEBUG) << "Trans format from nchw to Nc1h1wc0";
1549 MS_EXCEPTION_IF_NULL(result);
1550 if (args.host_shape.size() != kNchwDims) {
1551 MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
1552 return false;
1553 }
1554 auto size = abstract::TypeIdSize(args.src_data_type);
1555 if (size < 1) {
1556 MS_LOG(ERROR) << "Illegal dtype.";
1557 return false;
1558 }
1559 auto total_size = abstract::ShapeSize(args.device_shape) * size;
1560 if (total_size != args.device_size) {
1561 MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1562 return false;
1563 }
1564
1565 auto n = args.host_shape[kN];
1566 auto c = args.host_shape[kC];
1567 auto h = args.host_shape[kH];
1568 auto w = args.host_shape[kW];
1569 size_t c0 = kCubeSize;
1570 if (args.device_format == kOpFormat_NC1HWC0_C04) {
1571 c0 = kCubeSize_C04;
1572 }
1573 auto c1 = DivCeil(c, c0);
1574 auto hw = h * w;
1575 auto chw = c * hw;
1576 auto c1hwc0 = c1 * hw * c0;
1577 auto wc0 = w * c0;
1578
1579 for (size_t n_idx = 0; n_idx < n; n_idx++) {
1580 size_t n_head_addr = n_idx * c1hwc0;
1581 for (size_t c1_idx = 0; c1_idx < c1; c1_idx++) {
1582 size_t c1_head_addr = n_head_addr + c1_idx * hw * c0;
1583 for (size_t h_idx = 0; h_idx < h; h_idx++) {
1584 size_t h_head_addr = c1_head_addr + h_idx * wc0;
1585 for (size_t w_idx = 0; w_idx < w; w_idx++) {
1586 size_t w_head_addr = h_head_addr + w_idx * c0;
1587 for (size_t c0_idx = 0; c0_idx < c0; c0_idx++) {
1588 size_t dst_idx = c0_idx + w_head_addr;
1589 size_t c_idx = c0_idx + c1_idx * c0;
1590 size_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx;
1591 auto pad_zero = c_idx >= c;
1592 SetData(size, pad_zero, src_idx, dst_idx, args, result);
1593 }
1594 }
1595 }
1596 }
1597 }
1598 return true;
1599 }
1600
Nc1hwc0ToNchw(const FormatArgs & args,void * result)1601 bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
1602 MS_LOG(DEBUG) << "Trans format from nc1h1wc0 to nchw";
1603 MS_EXCEPTION_IF_NULL(result);
1604 if (args.host_shape.size() != kNchwDims) {
1605 MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
1606 return false;
1607 }
1608 auto size = abstract::TypeIdSize(args.src_data_type);
1609 if (size < 1) {
1610 MS_LOG(ERROR) << "Illegal dtype.";
1611 return false;
1612 }
1613 auto total_size = abstract::ShapeSize(args.device_shape) * size;
1614 if (total_size != args.device_size) {
1615 MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1616 return false;
1617 }
1618
1619 auto n = args.host_shape[kN];
1620 auto c = args.host_shape[kC];
1621 auto h = args.host_shape[kH];
1622 auto w = args.host_shape[kW];
1623 auto c1 = args.device_shape[1];
1624 auto c0 = args.device_shape[4];
1625
1626 auto hw = h * w;
1627 auto chw = c * hw;
1628 auto wc0 = w * c0;
1629 auto hwc0 = h * wc0;
1630 auto c1hwc0 = c1 * hwc0;
1631
1632 for (size_t n_idx = 0; n_idx < n; n_idx++) {
1633 size_t n_head_addr = n_idx * chw;
1634 for (size_t c_idx = 0; c_idx < c; c_idx++) {
1635 size_t c_head_addr = n_head_addr + c_idx * hw;
1636 for (size_t h_idx = 0; h_idx < h; h_idx++) {
1637 size_t h_head_addr = c_head_addr + h_idx * w;
1638 for (size_t w_idx = 0; w_idx < w; w_idx++) {
1639 size_t dst_idx = h_head_addr + w_idx;
1640 size_t c1_idx = c_idx / c0;
1641 size_t c0_idx = c_idx % c0;
1642 size_t src_idx = n_idx * c1hwc0 + c1_idx * hwc0 + h_idx * wc0 + w_idx * c0 + c0_idx;
1643 SetData(size, false, src_idx, dst_idx, args, result);
1644 }
1645 }
1646 }
1647 }
1648 return true;
1649 }
1650
NchwToC1hwncoc0(const FormatArgs & args,void * result)1651 bool NchwToC1hwncoc0(const FormatArgs &args, void *result) {
1652 // trans nchw to c1hwncoc0
1653 MS_LOG(DEBUG) << "Trans format from nchw to c1hwncoc0.";
1654 MS_EXCEPTION_IF_NULL(result);
1655 size_t size = 0;
1656 size_t total_size = 0;
1657 if (!CheckArgs(args, &size, &total_size)) {
1658 MS_LOG(ERROR) << "Check args failed.";
1659 return false;
1660 }
1661 auto n = args.host_shape[kN];
1662 auto c = args.host_shape[kC];
1663 auto h = args.host_shape[kH];
1664 auto w = args.host_shape[kW];
1665 const int co_idx = 4;
1666 const int c0_idx = 5;
1667 auto c1 = args.device_shape[0];
1668 auto co = args.device_shape[co_idx];
1669 auto c0 = args.device_shape[c0_idx];
1670
1671 for (size_t c1_i = 0; c1_i < c1; c1_i++) {
1672 for (size_t h_i = 0; h_i < h; h_i++) {
1673 for (size_t w_i = 0; w_i < w; w_i++) {
1674 for (size_t n_i = 0; n_i < n; n_i++) {
1675 for (size_t co_i = 0; co_i < co; co_i++) {
1676 for (size_t c0_i = 0; c0_i < c0; c0_i++) {
1677 size_t dst_idx = c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 +
1678 co_i * c0 + c0_i;
1679 size_t c_i = c0_i + c1_i * c0;
1680 size_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
1681 auto pad_zero = !(c_i < c && c0_i == co_i);
1682 SetData(size, pad_zero, src_idx, dst_idx, args, result);
1683 }
1684 }
1685 }
1686 }
1687 }
1688 }
1689 return true;
1690 }
1691
C1hwncoc0ToNchw(const FormatArgs & args,void * result)1692 bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) {
1693 // trans c1hwncoc0 to nchw
1694 MS_LOG(DEBUG) << "Trans format from c1hwncoc0 to nchw";
1695 MS_EXCEPTION_IF_NULL(result);
1696 size_t size = 0;
1697 size_t total_size = 0;
1698 if (!CheckArgs(args, &size, &total_size)) {
1699 MS_LOG(ERROR) << "Check args failed.";
1700 return false;
1701 }
1702 auto n = args.host_shape[kN];
1703 auto c = args.host_shape[kC];
1704 auto h = args.host_shape[kH];
1705 auto w = args.host_shape[kW];
1706 const int co_idx = 4;
1707 const int c0_idx = 5;
1708 auto co = args.device_shape[co_idx];
1709 auto c0 = args.device_shape[c0_idx];
1710 for (size_t n_i = 0; n_i < n; n_i++) {
1711 for (size_t c_i = 0; c_i < c; c_i++) {
1712 for (size_t h_i = 0; h_i < h; h_i++) {
1713 for (size_t w_i = 0; w_i < w; w_i++) {
1714 size_t dst_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
1715 size_t c1_i = c_i / kCubeSize;
1716 size_t c0_i = c_i % kCubeSize;
1717 size_t co_i = c0_i;
1718 size_t src_idx =
1719 c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 + co_i * c0 + c0_i;
1720 SetData(size, false, src_idx, dst_idx, args, result);
1721 }
1722 }
1723 }
1724 }
1725 return true;
1726 }
1727
Ndc1hwc0ToNcdhw(const FormatArgs & args,void * result)1728 bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result) {
1729 MS_LOG(DEBUG) << "Trans from ndc1hwc0 to ncdhw";
1730 MS_EXCEPTION_IF_NULL(result);
1731
1732 if (args.host_shape.size() != kNcdhw) {
1733 MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
1734 return false;
1735 }
1736 auto size = abstract::TypeIdSize(args.src_data_type);
1737 if (size < 1) {
1738 MS_LOG(ERROR) << "Illegal dtype.";
1739 return false;
1740 }
1741 auto total_size = abstract::ShapeSize(args.device_shape) * size;
1742 if (total_size != args.device_size) {
1743 MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1744 return false;
1745 }
1746 auto n = args.host_shape[N_ncdhw];
1747 auto c = args.host_shape[C_ncdhw];
1748 auto d = args.host_shape[D_ncdhw];
1749 auto h = args.host_shape[H_ncdhw];
1750 auto w = args.host_shape[W_ncdhw];
1751 auto c1 = args.device_shape[C1_ndc1hwc0];
1752 auto c0 = args.device_shape[C0_ndc1hwc0];
1753 const size_t cdhw = c * d * h * w;
1754 const size_t dhw = d * h * w;
1755 const size_t hw = h * w;
1756 const size_t dc1hwc0 = d * c1 * h * w * c0;
1757 const size_t c1hwc0 = c1 * h * w * c0;
1758 const size_t hwc0 = h * w * c0;
1759 const size_t wc0 = w * c0;
1760
1761 for (size_t n_i = 0; n_i < n; n_i++) {
1762 size_t n_head = n_i * cdhw;
1763 for (size_t c_i = 0; c_i < c; c_i++) {
1764 size_t c_head = n_head + c_i * dhw;
1765 for (size_t d_i = 0; d_i < d; d_i++) {
1766 size_t d_head = c_head + d_i * hw;
1767 for (size_t h_i = 0; h_i < h; h_i++) {
1768 size_t h_head = d_head + h_i * w;
1769 for (size_t w_i = 0; w_i < w; w_i++) {
1770 size_t dst_i = h_head + w_i;
1771 size_t c1_i = c_i / c0;
1772 size_t c0_i = c_i % c0;
1773 auto src_idx = n_i * dc1hwc0 + d_i * c1hwc0 + c1_i * hwc0 + h_i * wc0 + w_i * c0 + c0_i;
1774 SetData(size, false, src_idx, dst_i, args, result);
1775 }
1776 }
1777 }
1778 }
1779 }
1780 return true;
1781 }
1782
NcdhwToNdc1hwc0(const FormatArgs & args,void * result)1783 bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) {
1784 MS_LOG(DEBUG) << "Trans from ncdhw to ndc1hwc0";
1785 MS_EXCEPTION_IF_NULL(result);
1786
1787 if (args.host_shape.size() != kNcdhw) {
1788 MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
1789 return false;
1790 }
1791 auto size = abstract::TypeIdSize(args.src_data_type);
1792 if (size < 1) {
1793 MS_LOG(ERROR) << "Illegal dtype.";
1794 return false;
1795 }
1796 auto total_size = abstract::ShapeSize(args.device_shape) * size;
1797 if (total_size != args.device_size) {
1798 MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1799 return false;
1800 }
1801
1802 auto n = args.host_shape[N_ncdhw];
1803 auto c = args.host_shape[C_ncdhw];
1804 auto d = args.host_shape[D_ncdhw];
1805 auto h = args.host_shape[H_ncdhw];
1806 auto w = args.host_shape[W_ncdhw];
1807 auto c0 = kCubeSize;
1808 auto c1 = DivCeil(c, c0);
1809 const size_t cdhw = c * d * h * w;
1810 const size_t dhw = d * h * w;
1811 const size_t hw = h * w;
1812 const size_t dc1hwc0 = d * c1 * h * w * c0;
1813 const size_t c1hwc0 = c1 * h * w * c0;
1814 const size_t hwc0 = h * w * c0;
1815 const size_t wc0 = w * c0;
1816
1817 for (size_t n_i = 0; n_i < n; n_i++) {
1818 size_t n_head = n_i * dc1hwc0;
1819 for (size_t d_i = 0; d_i < d; d_i++) {
1820 size_t d_head = n_head + d_i * c1hwc0;
1821 for (size_t c1_i = 0; c1_i < c1; c1_i++) {
1822 size_t c1_head = d_head + c1_i * hwc0;
1823 for (size_t h_i = 0; h_i < h; h_i++) {
1824 size_t h_head = c1_head + h_i * wc0;
1825 for (size_t w_i = 0; w_i < w; w_i++) {
1826 size_t w_head = h_head + w_i * c0;
1827 for (size_t c0_i = 0; c0_i < c0; c0_i++) {
1828 size_t dst_i = c0_i + w_head;
1829 size_t c_i = c0_i + c1_i * c0;
1830 size_t src_i = n_i * cdhw + c_i * dhw + d_i * hw + h_i * w + w_i;
1831 auto pad_zero = c_i >= c;
1832 SetData(size, pad_zero, src_i, dst_i, args, result);
1833 }
1834 }
1835 }
1836 }
1837 }
1838 }
1839 return true;
1840 }
1841
NcdhwToFracZ3D(const FormatArgs & args,void * result)1842 bool NcdhwToFracZ3D(const FormatArgs &args, void *result) {
1843 MS_LOG(DEBUG) << "Trans from ncdhw to frac_z_3d";
1844 MS_EXCEPTION_IF_NULL(result);
1845
1846 if (args.host_shape.size() != kNcdhw) {
1847 MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
1848 return false;
1849 }
1850 auto size = abstract::TypeIdSize(args.src_data_type);
1851 if (size < 1) {
1852 MS_LOG(ERROR) << "Illegal dtype.";
1853 return false;
1854 }
1855 auto total_size = abstract::ShapeSize(args.device_shape) * size;
1856 if (total_size != args.device_size) {
1857 MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1858 return false;
1859 }
1860
1861 auto n = args.host_shape[N_ncdhw];
1862 auto c = args.host_shape[C_ncdhw];
1863 auto d = args.host_shape[D_ncdhw];
1864 auto h = args.host_shape[H_ncdhw];
1865 auto w = args.host_shape[W_ncdhw];
1866
1867 auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize;
1868 const size_t c0 = 16;
1869 auto c1 = DivCeil(c, c0);
1870 auto hw = h * w;
1871 auto dhw = d * hw;
1872 auto cdhw = c * dhw;
1873 auto n1n0c0 = n1n0 * c0;
1874 auto wn1n0c0 = w * n1n0c0;
1875 auto hwn1n0c0 = h * wn1n0c0;
1876 auto c1hwn1n0c0 = c1 * hwn1n0c0;
1877
1878 for (size_t d_i = 0; d_i < d; d_i++) {
1879 for (size_t c1_i = 0; c1_i < c1; c1_i++) {
1880 for (size_t h_i = 0; h_i < h; h_i++) {
1881 for (size_t w_i = 0; w_i < w; w_i++) {
1882 for (size_t n1n0_i = 0; n1n0_i < n1n0; n1n0_i++) {
1883 for (size_t c0_i = 0; c0_i < c0; c0_i++) {
1884 auto dst_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i;
1885 // ncdhw
1886 size_t src_i = n1n0_i * cdhw + (c1_i * c0 + c0_i) * dhw + d_i * hw + h_i * w + w_i;
1887 auto pad_zero = ((c1_i * c0 + c0_i) >= c) || (n1n0_i >= n);
1888 SetData(size, pad_zero, src_i, dst_i, args, result);
1889 }
1890 }
1891 }
1892 }
1893 }
1894 }
1895 return true;
1896 }
1897
FracZ3DToNcdhw(const FormatArgs & args,void * result)1898 bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
1899 MS_LOG(DEBUG) << "Trans from frac_z_3d to ncdhw";
1900 MS_EXCEPTION_IF_NULL(result);
1901
1902 if (args.host_shape.size() != kNcdhw) {
1903 MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
1904 return false;
1905 }
1906 auto size = abstract::TypeIdSize(args.src_data_type);
1907 if (size < 1) {
1908 MS_LOG(ERROR) << "Illegal dtype.";
1909 return false;
1910 }
1911 auto total_size = abstract::ShapeSize(args.device_shape) * size;
1912 if (total_size != args.device_size) {
1913 MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
1914 return false;
1915 }
1916 auto n = args.host_shape[N_ncdhw];
1917 auto c = args.host_shape[C_ncdhw];
1918 auto d = args.host_shape[D_ncdhw];
1919 auto h = args.host_shape[H_ncdhw];
1920 auto w = args.host_shape[W_ncdhw];
1921 const int kFZ3D_C0 = 3;
1922 auto c0 = args.device_shape[kFZ3D_C0];
1923 auto c1 = DivCeil(c, kCubeSize);
1924 auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize;
1925 auto n1n0c0 = n1n0 * c0;
1926 auto wn1n0c0 = w * n1n0c0;
1927 auto hwn1n0c0 = h * wn1n0c0;
1928 auto c1hwn1n0c0 = c1 * hwn1n0c0;
1929 auto hw = h * w;
1930 auto dhw = d * hw;
1931 auto cdhw = c * dhw;
1932
1933 for (size_t n_i = 0; n_i < n; n_i++) {
1934 size_t n_head = n_i * cdhw;
1935 for (size_t c_i = 0; c_i < c; c_i++) {
1936 size_t c_head = n_head + c_i * dhw;
1937 for (size_t d_i = 0; d_i < d; d_i++) {
1938 size_t d_head = c_head + d_i * hw;
1939 for (size_t h_i = 0; h_i < h; h_i++) {
1940 size_t h_head = d_head + h_i * w;
1941 for (size_t w_i = 0; w_i < w; w_i++) {
1942 size_t dst_i = h_head + w_i;
1943 size_t c1_i = c_i / c0;
1944 size_t c0_i = c_i % c0;
1945 size_t nc_i = n_i;
1946 size_t src_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + nc_i * c0 + c0_i;
1947 SetData(size, false, src_i, dst_i, args, result);
1948 }
1949 }
1950 }
1951 }
1952 }
1953 return true;
1954 }
1955
CheckDimsAndDtypes(const FormatArgs & args,size_t * size)1956 bool CheckDimsAndDtypes(const FormatArgs &args, size_t *size) {
1957 if (args.host_shape.size() != kNchwDims) {
1958 MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
1959 return false;
1960 }
1961 *size = abstract::TypeIdSize(args.src_data_type);
1962 if (*size < 1) {
1963 MS_LOG(ERROR) << "Illegal dtype";
1964 return false;
1965 }
1966 return true;
1967 }
1968
// Converts between host NCHW and the grouped FRACTAL_Z device layout.
// to_device == true packs host -> device; false unpacks device -> host.
// `groups` is the convolution group count: N is split into `groups` chunks of
// cout_ori filters, and e_mult groups are packed together so that both the
// input-channel (cin) and output-channel (cout) extents fill kCubeSize
// multiples (cin_opt / cout_opt). The destination buffer is zeroed first, so
// unused padded slots stay zero.
// NOTE(review): the e_mult = min(lcm-based factor, groups) packing appears to
// follow the Ascend FRACTAL_Z group-format rule — confirm against the CANN
// format documentation.
bool NchwFracZTransWithGroups(const FormatArgs &args, void *result, bool to_device, int64_t groups) {
  MS_EXCEPTION_IF_NULL(result);
  size_t size = 0;
  if (!(CheckDimsAndDtypes(args, &size))) {
    MS_LOG(ERROR) << "Illegal input args";
    return false;
  }
  if (groups <= 0) {
    MS_LOG(EXCEPTION) << "The value of groups should be greater than 0, but got " << groups;
  }
  auto n_dim = args.host_shape[kN];
  auto c_dim = args.host_shape[kC];
  auto h_dim = args.host_shape[kH];
  auto w_dim = args.host_shape[kW];
  const size_t d_dim = 1;  // no depth dim in NCHW; treated as 1
  size_t group_size = LongToSize(groups);
  auto cin_ori = c_dim;                  // input channels per group
  auto cout_ori = n_dim / group_size;    // output channels per group
  if (cin_ori == 0 || cout_ori == 0) {
    MS_LOG(ERROR) << "cin_ori, cout_ori must not equal to 0";
    return false;
  }
  // Number of groups packed into one fractal block.
  size_t e_mult = std::min(Lcm(Lcm(cin_ori, kCubeSize) / cin_ori, Lcm(cout_ori, kCubeSize) / cout_ori), group_size);
  if (e_mult == 0) {
    MS_LOG(EXCEPTION) << "The value of e_mult should be greater than 0, but got " << e_mult;
  }
  // Packed channel extents, rounded up to cube-size multiples.
  size_t cin_opt = DivCeil(e_mult * cin_ori, kCubeSize) * kCubeSize;
  size_t cout_opt = DivCeil(e_mult * cout_ori, kCubeSize) * kCubeSize;
  size_t c1_dim = cin_opt / kCubeSize;
  size_t dst_size = to_device ? GetShapeSize(args.device_shape) * size : GetShapeSize(args.host_shape) * size;
  if (dst_size == 0) {
    return true;
  }
  // Zero the whole destination so padded slots need no explicit writes.
  auto ret = memset_s(result, dst_size, 0, dst_size);
  if (ret != EOK) {
    MS_LOG(ERROR) << "memset failed";
    return false;
  }
  for (size_t g = 0; g < group_size; ++g) {
    for (size_t d = 0; d < d_dim; ++d) {
      for (size_t c = 0; c < c_dim; ++c) {
        for (size_t h = 0; h < h_dim; ++h) {
          for (size_t w = 0; w < w_dim; ++w) {
            for (size_t n = 0; n < cout_ori; ++n) {
              size_t e_val = g % e_mult;  // position of this group within its packed block
              size_t dst_ci = e_val * cin_ori + c;
              size_t dst_co = e_val * cout_ori + n;
              size_t src_co = g * cout_ori + n;
              size_t temporary = dst_ci % kCubeSize;  // c0 offset inside the cube
              // Device offset: [g / e_mult, d, dst_ci / kCubeSize, h, w, dst_co, c0].
              size_t dev_idx = (g / e_mult) * d_dim * c1_dim * h_dim * w_dim * cout_opt * kCubeSize +
                               d * c1_dim * h_dim * w_dim * cout_opt * kCubeSize +
                               (dst_ci / kCubeSize) * h_dim * w_dim * cout_opt * kCubeSize +
                               h * w_dim * cout_opt * kCubeSize + w * cout_opt * kCubeSize + dst_co * kCubeSize +
                               temporary;
              // Host offset in NCHW (with the implicit d dim of 1).
              size_t hst_idx =
                src_co * c_dim * d_dim * h_dim * w_dim + c * d_dim * h_dim * w_dim + d * h_dim * w_dim + h * w_dim + w;
              if (to_device) {
                SetData(size, false, hst_idx, dev_idx, args, result);
              } else {
                SetData(size, false, dev_idx, hst_idx, args, result);
              }
            }
          }
        }
      }
    }
  }
  return true;
}
2038
NchwToFracZWithGroups(const FormatArgs & args,void * result,int64_t groups)2039 bool NchwToFracZWithGroups(const FormatArgs &args, void *result, int64_t groups) {
2040 MS_LOG(DEBUG) << "Trans format from nchw to frac_z with groups=" << groups;
2041 return NchwFracZTransWithGroups(args, result, true, groups);
2042 }
2043
FracZToNchwWithGroups(const FormatArgs & args,void * result,int64_t groups)2044 bool FracZToNchwWithGroups(const FormatArgs &args, void *result, int64_t groups) {
2045 MS_LOG(DEBUG) << "Trans format from frac_z to nchw with groups=" << groups;
2046 return NchwFracZTransWithGroups(args, result, false, groups);
2047 }
2048 } // namespace trans
2049 } // namespace mindspore
2050