1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "transform/graph_ir/util.h"
18
19 #include <utility>
20 #include <map>
21
22 #include "securec/include/securec.h"
23 #include "utils/convert_utils.h"
24 #include "utils/utils.h"
25
26 namespace mindspore {
27 namespace transform {
28 using std::make_shared;
29 using std::shared_ptr;
30 using std::string;
31 using std::vector;
32
// Sentinel element size returned by GetDataTypeSize() for unsupported data types.
// Every supported type maps to a sizeof(...) >= 1, so 0 can never be a valid size.
constexpr size_t kErrorSize = 0;
34
ConvertIntToList(int64_t data,int size)35 vector<int64_t> TransformUtil::ConvertIntToList(int64_t data, int size) {
36 vector<int64_t> list{};
37 if (size <= 0) {
38 MS_LOG(WARNING) << "size <= 0";
39 return list;
40 }
41 for (int i = 0; i < size; ++i) {
42 list.push_back(data);
43 }
44 return list;
45 }
46
47 static std::map<MeDataType, GeDataType> datatype_trans_map = {
48 {MeDataType::kNumberTypeFloat16, GeDataType::DT_FLOAT16}, {MeDataType::kNumberTypeFloat32, GeDataType::DT_FLOAT},
49 {MeDataType::kNumberTypeFloat64, GeDataType::DT_DOUBLE}, {MeDataType::kNumberTypeInt8, GeDataType::DT_INT8},
50 {MeDataType::kNumberTypeInt16, GeDataType::DT_INT16}, {MeDataType::kNumberTypeInt32, GeDataType::DT_INT32},
51 {MeDataType::kNumberTypeInt64, GeDataType::DT_INT64}, {MeDataType::kNumberTypeUInt8, GeDataType::DT_UINT8},
52 {MeDataType::kNumberTypeUInt16, GeDataType::DT_UINT16}, {MeDataType::kNumberTypeUInt32, GeDataType::DT_UINT32},
53 {MeDataType::kNumberTypeUInt64, GeDataType::DT_UINT64}, {MeDataType::kNumberTypeBool, GeDataType::DT_BOOL}};
54
ConvertDataType(const MeDataType & type)55 GeDataType TransformUtil::ConvertDataType(const MeDataType &type) {
56 MS_LOG(DEBUG) << "Convert me data type: " << TypeIdLabel(type) << " to ge data type";
57 if (datatype_trans_map.find(type) != datatype_trans_map.end()) {
58 return datatype_trans_map[type];
59 } else {
60 return GeDataType::DT_UNDEFINED;
61 }
62 }
63
64 static std::map<MeDataType, size_t> datatype_size_map = {
65 {MeDataType::kNumberTypeFloat16, sizeof(float) / 2}, {MeDataType::kNumberTypeFloat32, sizeof(float)}, // 1/2 of float
66 {MeDataType::kNumberTypeFloat64, sizeof(double)}, {MeDataType::kNumberTypeInt8, sizeof(int8_t)},
67 {MeDataType::kNumberTypeInt16, sizeof(int16_t)}, {MeDataType::kNumberTypeInt32, sizeof(int32_t)},
68 {MeDataType::kNumberTypeInt64, sizeof(int64_t)}, {MeDataType::kNumberTypeUInt8, sizeof(uint8_t)},
69 {MeDataType::kNumberTypeUInt16, sizeof(uint16_t)}, {MeDataType::kNumberTypeUInt32, sizeof(uint32_t)},
70 {MeDataType::kNumberTypeUInt64, sizeof(uint64_t)}, {MeDataType::kNumberTypeBool, sizeof(bool)}};
71
GetDataTypeSize(const MeDataType & type)72 size_t TransformUtil::GetDataTypeSize(const MeDataType &type) {
73 if (datatype_size_map.find(type) != datatype_size_map.end()) {
74 return datatype_size_map[type];
75 } else {
76 MS_LOG(ERROR) << "Illegal tensor data type!";
77 return kErrorSize;
78 }
79 }
80
ConvertFormat(const string & format)81 GeFormat TransformUtil::ConvertFormat(const string &format) {
82 if (format == kOpFormat_NCHW) {
83 return GeFormat::FORMAT_NCHW;
84 } else if (format == kOpFormat_NDHWC) {
85 return GeFormat::FORMAT_NDHWC;
86 } else if (format == kOpFormat_NCDHW) {
87 return GeFormat::FORMAT_NCDHW;
88 } else if (format == kOpFormat_DHWNC) {
89 return GeFormat::FORMAT_DHWNC;
90 } else if (format == kOpFormat_DHWCN) {
91 return GeFormat::FORMAT_DHWCN;
92 } else if (format == kOpFormat_NC1HWC0) {
93 return GeFormat::FORMAT_NC1HWC0;
94 } else if (format == kOpFormat_NHWC) {
95 return GeFormat::FORMAT_NHWC;
96 } else if (format == kOpFormat_HWCN) {
97 return GeFormat::FORMAT_HWCN;
98 } else if (format == kOpFormat_ND) {
99 return GeFormat::FORMAT_ND;
100 } else {
101 MS_LOG(ERROR) << "Illegal tensor data format: (" << format << "). Use ND format instead.";
102 return GeFormat::FORMAT_ND;
103 }
104 }
105
// Converts an unsigned size value to a signed 64-bit integer. GetGeTensorDesc() uses it as the
// per-element transform when copying an ME shape into a GE shape vector.
static int64_t IntegerCastFunc(size_t temp) {
  const int64_t as_signed = static_cast<int64_t>(temp);
  return as_signed;
}
107
GetGeTensorDesc(const ShapeVector & me_shape,const MeDataType & me_type,const std::string & format)108 std::shared_ptr<GeTensorDesc> TransformUtil::GetGeTensorDesc(const ShapeVector &me_shape, const MeDataType &me_type,
109 const std::string &format) {
110 // convert me shape to ge shape
111 std::vector<int64_t> ge_shape;
112
113 if (me_shape.size() == 1) {
114 ge_shape.push_back(static_cast<int64_t>(me_shape[0]));
115 } else {
116 ge_shape.resize(me_shape.size());
117 (void)std::transform(me_shape.begin(), me_shape.end(), ge_shape.begin(), IntegerCastFunc);
118 }
119
120 GeShape shape(ge_shape);
121 if (shape.GetDimNum() == 0) {
122 MS_LOG(INFO) << "The dims size of Ge tensor is zero";
123 }
124 // convert me format to ge format
125 GeFormat ge_format = ConvertFormat(format);
126 if (ge_format == GeFormat::FORMAT_ND) {
127 MS_LOG(INFO) << "Set ND data format";
128 }
129 // convert me datatype to ge datatype
130 GeDataType data_type = ConvertDataType(me_type);
131 if (data_type == GeDataType::DT_UNDEFINED) {
132 MS_LOG(ERROR) << "undefined data type :" << me_type;
133 return nullptr;
134 }
135
136 auto desc = std::make_shared<GeTensorDesc>(shape, ge_format, data_type);
137 if (desc == nullptr) {
138 MS_LOG(ERROR) << "Create GeTensorDesc failed!";
139 return nullptr;
140 }
141 MS_LOG(INFO) << "SetRealDimCnt is :" << me_shape.size();
142 desc->SetRealDimCnt(SizeToInt(me_shape.size()));
143 return desc;
144 }
145
146 // if failed, return empty vector.
ConvertInputTensors(const std::vector<MeTensorPtr> & me_tensors,const std::string & format)147 std::vector<GeTensorPtr> TransformUtil::ConvertInputTensors(const std::vector<MeTensorPtr> &me_tensors,
148 const std::string &format) {
149 std::vector<GeTensorPtr> ge_tensors;
150
151 for (size_t index = 0; index < me_tensors.size(); index++) {
152 MS_EXCEPTION_IF_NULL(me_tensors[index]);
153 MS_LOG(INFO) << "me_tensor " << index << " 's data size is: " << me_tensors[index]->DataSize();
154 auto shape = me_tensors[index]->shape();
155 std::string shape_str;
156 for (size_t i = 0; i < shape.size(); i++) {
157 shape_str += std::to_string(shape[i]);
158 shape_str += " ";
159 }
160 MS_LOG(INFO) << "me_tensor " << index << " 's shape is: { " << shape_str << "}";
161 MS_LOG(INFO) << "me_tensor " << index << " 's type is: " << me_tensors[index]->data_type();
162
163 auto ge_tensor_ptr = TransformUtil::ConvertTensor(me_tensors[index], format);
164 if (ge_tensor_ptr != nullptr) {
165 ge_tensors.emplace_back(ge_tensor_ptr);
166 } else {
167 MS_LOG(ERROR) << "Convert me_tensor " << index << " to Ge Tensor failed!";
168 ge_tensors.clear();
169 return ge_tensors;
170 }
171 }
172 return ge_tensors;
173 }
174
ConvertTensor(const MeTensorPtr & tensor,const std::string & format)175 GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr &tensor, const std::string &format) {
176 // get tensor data type size
177 MS_EXCEPTION_IF_NULL(tensor);
178 size_t type_size = GetDataTypeSize(tensor->data_type());
179 if (type_size == kErrorSize) {
180 MS_LOG(ERROR) << "The Me Tensor data type size is wrong, type size is: " << type_size;
181 return nullptr;
182 }
183 size_t elements_num = IntToSize(tensor->ElementsNum());
184
185 // get tensor buff size
186 size_t data_buff_size = elements_num * type_size;
187 if (data_buff_size == 0) {
188 MS_LOG(INFO) << "The Me Tensor data buff size is 0.";
189 }
190 // create ge tensor
191 auto desc = GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format);
192 if (desc == nullptr) {
193 MS_LOG(ERROR) << "Failed to get Tensor Desc";
194 return nullptr;
195 }
196 GeTensorPtr tensor_ptr = make_shared<GeTensor>(*desc, static_cast<uint8_t *>(tensor->data_c()), data_buff_size);
197 if (tensor_ptr != nullptr) {
198 MS_LOG(INFO) << "Convert Me Tensor to Ge Tensor success!";
199 }
200 return tensor_ptr;
201 }
202
ConvertGeTensors(const std::vector<GeTensorPtr> & ge_tensors,const std::vector<ShapeVector> & request_dims)203 std::vector<MeTensorPtr> TransformUtil::ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors,
204 const std::vector<ShapeVector> &request_dims) {
205 std::vector<MeTensorPtr> outputs;
206
207 for (size_t index = 0; index < ge_tensors.size(); index++) {
208 MeTensorPtr me_tensor_ptr = nullptr;
209 if (index < request_dims.size()) {
210 me_tensor_ptr = ConvertGeTensor(ge_tensors[index], request_dims[index]);
211 } else {
212 ShapeVector empty_shape;
213 me_tensor_ptr = ConvertGeTensor(ge_tensors[index], empty_shape);
214 }
215
216 if (me_tensor_ptr != nullptr) {
217 outputs.emplace_back(me_tensor_ptr);
218 } else {
219 MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!";
220 return outputs;
221 }
222 }
223 return outputs;
224 }
225
ConvertGeTensors(const std::vector<GeTensorPtr> & ge_tensors)226 std::vector<MeTensorPtr> TransformUtil::ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors) {
227 std::vector<MeTensorPtr> outputs;
228
229 for (size_t index = 0; index < ge_tensors.size(); index++) {
230 MeTensorPtr me_tensor_ptr = ConvertGeTensor(ge_tensors[index]);
231 if (me_tensor_ptr != nullptr) {
232 outputs.emplace_back(me_tensor_ptr);
233 } else {
234 MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!";
235 return outputs;
236 }
237 }
238 return outputs;
239 }
240
ConvertGeDataType(const GeDataType & type)241 MeDataType TransformUtil::ConvertGeDataType(const GeDataType &type) {
242 switch (type) {
243 case GeDataType::DT_FLOAT16:
244 return MeDataType::kNumberTypeFloat16;
245 case GeDataType::DT_FLOAT:
246 return MeDataType::kNumberTypeFloat32;
247 case GeDataType::DT_DOUBLE:
248 return MeDataType::kNumberTypeFloat64;
249 case GeDataType::DT_INT64:
250 return MeDataType::kNumberTypeInt64;
251 case GeDataType::DT_INT32:
252 return MeDataType::kNumberTypeInt32;
253 case GeDataType::DT_INT16:
254 return MeDataType::kNumberTypeInt16;
255 case GeDataType::DT_INT8:
256 return MeDataType::kNumberTypeInt8;
257 case GeDataType::DT_BOOL:
258 return MeDataType::kNumberTypeBool;
259 case GeDataType::DT_UINT8:
260 return MeDataType::kNumberTypeUInt8;
261 case GeDataType::DT_UINT16:
262 return MeDataType::kNumberTypeUInt16;
263 case GeDataType::DT_UINT32:
264 return MeDataType::kNumberTypeUInt32;
265 case GeDataType::DT_UINT64:
266 return MeDataType::kNumberTypeUInt64;
267 case GeDataType::DT_UNDEFINED:
268 case GeDataType::DT_DUAL_SUB_UINT8:
269 case GeDataType::DT_DUAL_SUB_INT8:
270 case GeDataType::DT_DUAL:
271 return MeDataType::kTypeUnknown;
272 default:
273 return MeDataType::kTypeUnknown;
274 }
275 }
276
277 namespace {
IsGeShapeCompatible(const GeShape & ge_shape,const ShapeVector & request_dims)278 bool IsGeShapeCompatible(const GeShape &ge_shape, const ShapeVector &request_dims) {
279 MS_LOG(INFO) << "GeTensor's shape is " << TransformUtil::PrintVector(ge_shape.GetDims());
280 MS_LOG(INFO) << "Me request shape is " << TransformUtil::PrintVector(request_dims);
281
282 const int GE_DIMS = 4;
283 std::vector<int64_t> ge_dims = ge_shape.GetDims();
284 if (request_dims.size() > ge_dims.size()) {
285 MS_LOG(ERROR) << "Request shape's dims count greater than ge shape's";
286 return false;
287 }
288
289 // convert NHWC to NCHW
290 if ((request_dims.size() == 1) && (ge_dims.size() == GE_DIMS) && (request_dims[0] == ge_dims[1]) &&
291 (ge_dims[0] == 1) && (ge_dims[2] == 1) && (ge_dims[3] == 1)) {
292 MS_LOG(INFO) << "Ge tensor shape and request shape is compatible";
293 return true;
294 }
295
296 std::string::size_type i = 0;
297 for (; i < request_dims.size(); i++) {
298 if (ge_dims[i] != request_dims[i]) {
299 MS_LOG(ERROR) << "Request shape's dims value not equal to ge shape's";
300 return false;
301 }
302 }
303
304 for (; i < ge_dims.size(); i++) {
305 if (ge_dims[i] != 1) {
306 MS_LOG(ERROR) << "GeShape's extend dims is not equal to 1";
307 return false;
308 }
309 }
310 MS_LOG(INFO) << "Ge tensor shape and request shape is compatible";
311 return true;
312 }
313 } // namespace
314
ConvertMeShape(const ShapeVector & me_dims)315 GeShape TransformUtil::ConvertMeShape(const ShapeVector &me_dims) {
316 std::vector<int64_t> ge_dims;
317 (void)std::copy(me_dims.begin(), me_dims.end(), std::back_inserter(ge_dims));
318 return GeShape(ge_dims);
319 }
320
ConvertGeShape(const GeShape & ge_shape)321 ShapeVector TransformUtil::ConvertGeShape(const GeShape &ge_shape) {
322 ShapeVector me_dims;
323 std::vector<int64_t> ge_dims = ge_shape.GetDims();
324 (void)std::copy(ge_dims.begin(), ge_dims.end(), std::back_inserter(me_dims));
325 return me_dims;
326 }
327
/// Chooses the ME shape for a GE tensor, given the shape the caller asked for.
///  - A rank-0 GE shape gives an empty (scalar) shape; request_dims is not consulted.
///  - If IsGeShapeCompatible() accepts request_dims, request_dims is returned.
///  - Otherwise an error is logged and the GE tensor's own dims are returned.
ShapeVector TransformUtil::ConvertGeShape(const GeShape &ge_shape, const ShapeVector &request_dims) {
  if (ge_shape.GetDimNum() == 0) {
    MS_LOG(DEBUG) << "GeTensor's shape is scalar";
    return {};
  }
  if (!IsGeShapeCompatible(ge_shape, request_dims)) {
    MS_LOG(ERROR) << "GeShape and Me request shape are incompatible, return GeShape";
    return ConvertGeShape(ge_shape);
  }
  return request_dims;
}
343
GenerateMeTensor(const GeTensorPtr & ge_tensor,const ShapeVector & me_dims,const TypeId & me_type)344 MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr &ge_tensor, const ShapeVector &me_dims,
345 const TypeId &me_type) {
346 MeTensor me_tensor(me_type, me_dims);
347
348 // Get the writable data pointer of the tensor and cast it to its data type
349 auto me_data_ptr = reinterpret_cast<uint8_t *>(me_tensor.data_c());
350 size_t me_data_size = static_cast<size_t>(me_tensor.data().nbytes());
351 MS_EXCEPTION_IF_NULL(me_data_ptr);
352 MS_EXCEPTION_IF_NULL(ge_tensor);
353 if (me_data_size < ge_tensor->GetSize()) {
354 MS_LOG(ERROR) << "ME tensor data size[" << me_data_size << " bytes] is less than GE tensor ["
355 << ge_tensor->GetSize() << " bytes]";
356 return nullptr;
357 }
358
359 // Copy or use the writable data pointer of the ME tensor
360 MS_EXCEPTION_IF_NULL(ge_tensor->GetData());
361 if (ge_tensor->GetSize() == 0) {
362 MS_LOG(ERROR) << "GE tensor data size is zero!";
363 return nullptr;
364 }
365
366 // Use memcpy here, not memcpy_s, just because the size of ge_tensor may be bigger than 2GB
367 // which is the size limit of memcpy_s
368 memcpy(me_data_ptr, ge_tensor->GetData(), ge_tensor->GetSize());
369
370 return make_shared<MeTensor>(me_tensor);
371 }
372
ConvertGeTensor(const GeTensorPtr & ge_tensor)373 MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr &ge_tensor) {
374 MS_EXCEPTION_IF_NULL(ge_tensor);
375 GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape();
376 vector<int64_t> me_dims = ConvertGeShape(ge_shape);
377
378 TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType());
379 if (type_id == MeDataType::kTypeUnknown) {
380 MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: "
381 << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
382 return nullptr;
383 }
384 return GenerateMeTensor(ge_tensor, me_dims, type_id);
385 }
386
387 // if request_dims is empty, use ge tensor's shape,otherwise convert to request shape
ConvertGeTensor(const GeTensorPtr ge_tensor,const ShapeVector & request_dims)388 MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr ge_tensor, const ShapeVector &request_dims) {
389 MS_EXCEPTION_IF_NULL(ge_tensor);
390 GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape();
391 vector<int64_t> me_dims = ConvertGeShape(ge_shape, request_dims);
392 MS_LOG(INFO) << "GE tensor type is " << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
393 // Create a tensor with wanted data type and shape
394 TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType());
395 if (type_id == MeDataType::kTypeUnknown) {
396 MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: "
397 << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
398 return nullptr;
399 }
400 return GenerateMeTensor(ge_tensor, me_dims, type_id);
401 }
402
PrintGeTensor(const GeTensorPtr ge_tensor)403 std::string TransformUtil::PrintGeTensor(const GeTensorPtr ge_tensor) {
404 std::string ret;
405 if (ge_tensor == nullptr) {
406 MS_LOG(ERROR) << "Input ge tensor is nullptr";
407 return ret;
408 }
409
410 MS_LOG(INFO) << "Ge Tensor data type is : " << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
411 switch (ge_tensor->GetTensorDesc().GetDataType()) {
412 case GeDataType::DT_UINT32:
413 ret = PrintVector(MakeVector<uint32_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
414 break;
415 case GeDataType::DT_FLOAT:
416 ret = PrintVector(MakeVector<float_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
417 break;
418 case GeDataType::DT_INT32:
419 ret = PrintVector(MakeVector<int32_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
420 break;
421 case GeDataType::DT_DOUBLE:
422 ret = PrintVector(MakeVector<double_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
423 break;
424 case GeDataType::DT_INT64:
425 ret = PrintVector(MakeVector<int64_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
426 break;
427 case GeDataType::DT_UINT64:
428 ret = PrintVector(MakeVector<uint64_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
429 break;
430 case GeDataType::DT_INT16:
431 ret = PrintVector(MakeVector<int16_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
432 break;
433 case GeDataType::DT_UINT16:
434 ret = PrintVector(MakeVector<uint16_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
435 break;
436 case GeDataType::DT_DUAL_SUB_INT8:
437 case GeDataType::DT_INT8:
438 ret = PrintVector(MakeVector<int8_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
439 break;
440 case GeDataType::DT_UINT8:
441 case GeDataType::DT_DUAL_SUB_UINT8:
442 ret = PrintVector(MakeVector<uint8_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
443 break;
444 case GeDataType::DT_FLOAT16:
445 case GeDataType::DT_BOOL:
446 case GeDataType::DT_UNDEFINED:
447 case GeDataType::DT_DUAL:
448 default:
449 MS_LOG(ERROR) << "Unsupported to print type:" << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType())
450 << " ge tensor";
451 break;
452 }
453 return ret;
454 }
455 } // namespace transform
456 } // namespace mindspore
457