• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "transform/graph_ir/util.h"
18 
19 #include <utility>
20 #include <map>
21 
22 #include "securec/include/securec.h"
23 #include "utils/convert_utils.h"
24 #include "utils/utils.h"
25 
26 namespace mindspore {
27 namespace transform {
28 using std::make_shared;
29 using std::shared_ptr;
30 using std::string;
31 using std::vector;
32 
// Sentinel element size returned by GetDataTypeSize() for data types that have no known size.
constexpr size_t kErrorSize = 0;
34 
ConvertIntToList(int64_t data,int size)35 vector<int64_t> TransformUtil::ConvertIntToList(int64_t data, int size) {
36   vector<int64_t> list{};
37   if (size <= 0) {
38     MS_LOG(WARNING) << "size <= 0";
39     return list;
40   }
41   for (int i = 0; i < size; ++i) {
42     list.push_back(data);
43   }
44   return list;
45 }
46 
47 static std::map<MeDataType, GeDataType> datatype_trans_map = {
48   {MeDataType::kNumberTypeFloat16, GeDataType::DT_FLOAT16}, {MeDataType::kNumberTypeFloat32, GeDataType::DT_FLOAT},
49   {MeDataType::kNumberTypeFloat64, GeDataType::DT_DOUBLE},  {MeDataType::kNumberTypeInt8, GeDataType::DT_INT8},
50   {MeDataType::kNumberTypeInt16, GeDataType::DT_INT16},     {MeDataType::kNumberTypeInt32, GeDataType::DT_INT32},
51   {MeDataType::kNumberTypeInt64, GeDataType::DT_INT64},     {MeDataType::kNumberTypeUInt8, GeDataType::DT_UINT8},
52   {MeDataType::kNumberTypeUInt16, GeDataType::DT_UINT16},   {MeDataType::kNumberTypeUInt32, GeDataType::DT_UINT32},
53   {MeDataType::kNumberTypeUInt64, GeDataType::DT_UINT64},   {MeDataType::kNumberTypeBool, GeDataType::DT_BOOL}};
54 
ConvertDataType(const MeDataType & type)55 GeDataType TransformUtil::ConvertDataType(const MeDataType &type) {
56   MS_LOG(DEBUG) << "Convert me data type: " << TypeIdLabel(type) << " to ge data type";
57   if (datatype_trans_map.find(type) != datatype_trans_map.end()) {
58     return datatype_trans_map[type];
59   } else {
60     return GeDataType::DT_UNDEFINED;
61   }
62 }
63 
64 static std::map<MeDataType, size_t> datatype_size_map = {
65   {MeDataType::kNumberTypeFloat16, sizeof(float) / 2}, {MeDataType::kNumberTypeFloat32, sizeof(float)},  // 1/2 of float
66   {MeDataType::kNumberTypeFloat64, sizeof(double)},    {MeDataType::kNumberTypeInt8, sizeof(int8_t)},
67   {MeDataType::kNumberTypeInt16, sizeof(int16_t)},     {MeDataType::kNumberTypeInt32, sizeof(int32_t)},
68   {MeDataType::kNumberTypeInt64, sizeof(int64_t)},     {MeDataType::kNumberTypeUInt8, sizeof(uint8_t)},
69   {MeDataType::kNumberTypeUInt16, sizeof(uint16_t)},   {MeDataType::kNumberTypeUInt32, sizeof(uint32_t)},
70   {MeDataType::kNumberTypeUInt64, sizeof(uint64_t)},   {MeDataType::kNumberTypeBool, sizeof(bool)}};
71 
GetDataTypeSize(const MeDataType & type)72 size_t TransformUtil::GetDataTypeSize(const MeDataType &type) {
73   if (datatype_size_map.find(type) != datatype_size_map.end()) {
74     return datatype_size_map[type];
75   } else {
76     MS_LOG(ERROR) << "Illegal tensor data type!";
77     return kErrorSize;
78   }
79 }
80 
ConvertFormat(const string & format)81 GeFormat TransformUtil::ConvertFormat(const string &format) {
82   if (format == kOpFormat_NCHW) {
83     return GeFormat::FORMAT_NCHW;
84   } else if (format == kOpFormat_NDHWC) {
85     return GeFormat::FORMAT_NDHWC;
86   } else if (format == kOpFormat_NCDHW) {
87     return GeFormat::FORMAT_NCDHW;
88   } else if (format == kOpFormat_DHWNC) {
89     return GeFormat::FORMAT_DHWNC;
90   } else if (format == kOpFormat_DHWCN) {
91     return GeFormat::FORMAT_DHWCN;
92   } else if (format == kOpFormat_NC1HWC0) {
93     return GeFormat::FORMAT_NC1HWC0;
94   } else if (format == kOpFormat_NHWC) {
95     return GeFormat::FORMAT_NHWC;
96   } else if (format == kOpFormat_HWCN) {
97     return GeFormat::FORMAT_HWCN;
98   } else if (format == kOpFormat_ND) {
99     return GeFormat::FORMAT_ND;
100   } else {
101     MS_LOG(ERROR) << "Illegal tensor data format: (" << format << "). Use ND format instead.";
102     return GeFormat::FORMAT_ND;
103   }
104 }
105 
IntegerCastFunc(size_t temp)106 static int64_t IntegerCastFunc(size_t temp) { return static_cast<int64_t>(temp); }
107 
GetGeTensorDesc(const ShapeVector & me_shape,const MeDataType & me_type,const std::string & format)108 std::shared_ptr<GeTensorDesc> TransformUtil::GetGeTensorDesc(const ShapeVector &me_shape, const MeDataType &me_type,
109                                                              const std::string &format) {
110   // convert me shape to ge shape
111   std::vector<int64_t> ge_shape;
112 
113   if (me_shape.size() == 1) {
114     ge_shape.push_back(static_cast<int64_t>(me_shape[0]));
115   } else {
116     ge_shape.resize(me_shape.size());
117     (void)std::transform(me_shape.begin(), me_shape.end(), ge_shape.begin(), IntegerCastFunc);
118   }
119 
120   GeShape shape(ge_shape);
121   if (shape.GetDimNum() == 0) {
122     MS_LOG(INFO) << "The dims size of Ge tensor is zero";
123   }
124   // convert me format to ge format
125   GeFormat ge_format = ConvertFormat(format);
126   if (ge_format == GeFormat::FORMAT_ND) {
127     MS_LOG(INFO) << "Set ND data format";
128   }
129   // convert me datatype to ge datatype
130   GeDataType data_type = ConvertDataType(me_type);
131   if (data_type == GeDataType::DT_UNDEFINED) {
132     MS_LOG(ERROR) << "undefined data type :" << me_type;
133     return nullptr;
134   }
135 
136   auto desc = std::make_shared<GeTensorDesc>(shape, ge_format, data_type);
137   if (desc == nullptr) {
138     MS_LOG(ERROR) << "Create GeTensorDesc failed!";
139     return nullptr;
140   }
141   MS_LOG(INFO) << "SetRealDimCnt is :" << me_shape.size();
142   desc->SetRealDimCnt(SizeToInt(me_shape.size()));
143   return desc;
144 }
145 
146 // if failed, return empty vector.
ConvertInputTensors(const std::vector<MeTensorPtr> & me_tensors,const std::string & format)147 std::vector<GeTensorPtr> TransformUtil::ConvertInputTensors(const std::vector<MeTensorPtr> &me_tensors,
148                                                             const std::string &format) {
149   std::vector<GeTensorPtr> ge_tensors;
150 
151   for (size_t index = 0; index < me_tensors.size(); index++) {
152     MS_EXCEPTION_IF_NULL(me_tensors[index]);
153     MS_LOG(INFO) << "me_tensor " << index << " 's data size is: " << me_tensors[index]->DataSize();
154     auto shape = me_tensors[index]->shape();
155     std::string shape_str;
156     for (size_t i = 0; i < shape.size(); i++) {
157       shape_str += std::to_string(shape[i]);
158       shape_str += " ";
159     }
160     MS_LOG(INFO) << "me_tensor " << index << " 's shape is: { " << shape_str << "}";
161     MS_LOG(INFO) << "me_tensor " << index << " 's type is: " << me_tensors[index]->data_type();
162 
163     auto ge_tensor_ptr = TransformUtil::ConvertTensor(me_tensors[index], format);
164     if (ge_tensor_ptr != nullptr) {
165       ge_tensors.emplace_back(ge_tensor_ptr);
166     } else {
167       MS_LOG(ERROR) << "Convert me_tensor " << index << " to Ge Tensor failed!";
168       ge_tensors.clear();
169       return ge_tensors;
170     }
171   }
172   return ge_tensors;
173 }
174 
ConvertTensor(const MeTensorPtr & tensor,const std::string & format)175 GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr &tensor, const std::string &format) {
176   // get tensor data type size
177   MS_EXCEPTION_IF_NULL(tensor);
178   size_t type_size = GetDataTypeSize(tensor->data_type());
179   if (type_size == kErrorSize) {
180     MS_LOG(ERROR) << "The Me Tensor data type size is wrong, type size is: " << type_size;
181     return nullptr;
182   }
183   size_t elements_num = IntToSize(tensor->ElementsNum());
184 
185   // get tensor buff size
186   size_t data_buff_size = elements_num * type_size;
187   if (data_buff_size == 0) {
188     MS_LOG(INFO) << "The Me Tensor data buff size is 0.";
189   }
190   // create ge tensor
191   auto desc = GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format);
192   if (desc == nullptr) {
193     MS_LOG(ERROR) << "Failed to get Tensor Desc";
194     return nullptr;
195   }
196   GeTensorPtr tensor_ptr = make_shared<GeTensor>(*desc, static_cast<uint8_t *>(tensor->data_c()), data_buff_size);
197   if (tensor_ptr != nullptr) {
198     MS_LOG(INFO) << "Convert Me Tensor to Ge Tensor success!";
199   }
200   return tensor_ptr;
201 }
202 
ConvertGeTensors(const std::vector<GeTensorPtr> & ge_tensors,const std::vector<ShapeVector> & request_dims)203 std::vector<MeTensorPtr> TransformUtil::ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors,
204                                                          const std::vector<ShapeVector> &request_dims) {
205   std::vector<MeTensorPtr> outputs;
206 
207   for (size_t index = 0; index < ge_tensors.size(); index++) {
208     MeTensorPtr me_tensor_ptr = nullptr;
209     if (index < request_dims.size()) {
210       me_tensor_ptr = ConvertGeTensor(ge_tensors[index], request_dims[index]);
211     } else {
212       ShapeVector empty_shape;
213       me_tensor_ptr = ConvertGeTensor(ge_tensors[index], empty_shape);
214     }
215 
216     if (me_tensor_ptr != nullptr) {
217       outputs.emplace_back(me_tensor_ptr);
218     } else {
219       MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!";
220       return outputs;
221     }
222   }
223   return outputs;
224 }
225 
ConvertGeTensors(const std::vector<GeTensorPtr> & ge_tensors)226 std::vector<MeTensorPtr> TransformUtil::ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors) {
227   std::vector<MeTensorPtr> outputs;
228 
229   for (size_t index = 0; index < ge_tensors.size(); index++) {
230     MeTensorPtr me_tensor_ptr = ConvertGeTensor(ge_tensors[index]);
231     if (me_tensor_ptr != nullptr) {
232       outputs.emplace_back(me_tensor_ptr);
233     } else {
234       MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!";
235       return outputs;
236     }
237   }
238   return outputs;
239 }
240 
ConvertGeDataType(const GeDataType & type)241 MeDataType TransformUtil::ConvertGeDataType(const GeDataType &type) {
242   switch (type) {
243     case GeDataType::DT_FLOAT16:
244       return MeDataType::kNumberTypeFloat16;
245     case GeDataType::DT_FLOAT:
246       return MeDataType::kNumberTypeFloat32;
247     case GeDataType::DT_DOUBLE:
248       return MeDataType::kNumberTypeFloat64;
249     case GeDataType::DT_INT64:
250       return MeDataType::kNumberTypeInt64;
251     case GeDataType::DT_INT32:
252       return MeDataType::kNumberTypeInt32;
253     case GeDataType::DT_INT16:
254       return MeDataType::kNumberTypeInt16;
255     case GeDataType::DT_INT8:
256       return MeDataType::kNumberTypeInt8;
257     case GeDataType::DT_BOOL:
258       return MeDataType::kNumberTypeBool;
259     case GeDataType::DT_UINT8:
260       return MeDataType::kNumberTypeUInt8;
261     case GeDataType::DT_UINT16:
262       return MeDataType::kNumberTypeUInt16;
263     case GeDataType::DT_UINT32:
264       return MeDataType::kNumberTypeUInt32;
265     case GeDataType::DT_UINT64:
266       return MeDataType::kNumberTypeUInt64;
267     case GeDataType::DT_UNDEFINED:
268     case GeDataType::DT_DUAL_SUB_UINT8:
269     case GeDataType::DT_DUAL_SUB_INT8:
270     case GeDataType::DT_DUAL:
271       return MeDataType::kTypeUnknown;
272     default:
273       return MeDataType::kTypeUnknown;
274   }
275 }
276 
277 namespace {
IsGeShapeCompatible(const GeShape & ge_shape,const ShapeVector & request_dims)278 bool IsGeShapeCompatible(const GeShape &ge_shape, const ShapeVector &request_dims) {
279   MS_LOG(INFO) << "GeTensor's shape is " << TransformUtil::PrintVector(ge_shape.GetDims());
280   MS_LOG(INFO) << "Me request shape is " << TransformUtil::PrintVector(request_dims);
281 
282   const int GE_DIMS = 4;
283   std::vector<int64_t> ge_dims = ge_shape.GetDims();
284   if (request_dims.size() > ge_dims.size()) {
285     MS_LOG(ERROR) << "Request shape's dims count greater than ge shape's";
286     return false;
287   }
288 
289   // convert NHWC to NCHW
290   if ((request_dims.size() == 1) && (ge_dims.size() == GE_DIMS) && (request_dims[0] == ge_dims[1]) &&
291       (ge_dims[0] == 1) && (ge_dims[2] == 1) && (ge_dims[3] == 1)) {
292     MS_LOG(INFO) << "Ge tensor shape and request shape is compatible";
293     return true;
294   }
295 
296   std::string::size_type i = 0;
297   for (; i < request_dims.size(); i++) {
298     if (ge_dims[i] != request_dims[i]) {
299       MS_LOG(ERROR) << "Request shape's dims value not equal to ge shape's";
300       return false;
301     }
302   }
303 
304   for (; i < ge_dims.size(); i++) {
305     if (ge_dims[i] != 1) {
306       MS_LOG(ERROR) << "GeShape's extend dims is not equal to 1";
307       return false;
308     }
309   }
310   MS_LOG(INFO) << "Ge tensor shape and request shape is compatible";
311   return true;
312 }
313 }  // namespace
314 
ConvertMeShape(const ShapeVector & me_dims)315 GeShape TransformUtil::ConvertMeShape(const ShapeVector &me_dims) {
316   std::vector<int64_t> ge_dims;
317   (void)std::copy(me_dims.begin(), me_dims.end(), std::back_inserter(ge_dims));
318   return GeShape(ge_dims);
319 }
320 
ConvertGeShape(const GeShape & ge_shape)321 ShapeVector TransformUtil::ConvertGeShape(const GeShape &ge_shape) {
322   ShapeVector me_dims;
323   std::vector<int64_t> ge_dims = ge_shape.GetDims();
324   (void)std::copy(ge_dims.begin(), ge_dims.end(), std::back_inserter(me_dims));
325   return me_dims;
326 }
327 
// Chooses the ME shape to use for a GE tensor, given the shape the caller asked for.
//   * GE shape with zero dims            -> empty shape.
//   * GE shape compatible with request   -> `request_dims`.
//   * otherwise                          -> the GE shape itself (an error is logged).
ShapeVector TransformUtil::ConvertGeShape(const GeShape &ge_shape, const ShapeVector &request_dims) {
  if (ge_shape.GetDimNum() == 0) {
    MS_LOG(DEBUG) << "GeTensor's shape is scalar";
    return {};
  }

  if (!IsGeShapeCompatible(ge_shape, request_dims)) {
    MS_LOG(ERROR) << "GeShape and Me request shape are incompatible, return GeShape";
    return ConvertGeShape(ge_shape);
  }
  return request_dims;
}
343 
GenerateMeTensor(const GeTensorPtr & ge_tensor,const ShapeVector & me_dims,const TypeId & me_type)344 MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr &ge_tensor, const ShapeVector &me_dims,
345                                             const TypeId &me_type) {
346   MeTensor me_tensor(me_type, me_dims);
347 
348   // Get the writable data pointer of the tensor and cast it to its data type
349   auto me_data_ptr = reinterpret_cast<uint8_t *>(me_tensor.data_c());
350   size_t me_data_size = static_cast<size_t>(me_tensor.data().nbytes());
351   MS_EXCEPTION_IF_NULL(me_data_ptr);
352   MS_EXCEPTION_IF_NULL(ge_tensor);
353   if (me_data_size < ge_tensor->GetSize()) {
354     MS_LOG(ERROR) << "ME tensor data size[" << me_data_size << " bytes] is less than GE tensor ["
355                   << ge_tensor->GetSize() << " bytes]";
356     return nullptr;
357   }
358 
359   // Copy or use the writable data pointer of the ME tensor
360   MS_EXCEPTION_IF_NULL(ge_tensor->GetData());
361   if (ge_tensor->GetSize() == 0) {
362     MS_LOG(ERROR) << "GE tensor data size is zero!";
363     return nullptr;
364   }
365 
366   // Use memcpy here, not memcpy_s, just because the size of ge_tensor may be bigger than 2GB
367   // which is the size limit of memcpy_s
368   memcpy(me_data_ptr, ge_tensor->GetData(), ge_tensor->GetSize());
369 
370   return make_shared<MeTensor>(me_tensor);
371 }
372 
ConvertGeTensor(const GeTensorPtr & ge_tensor)373 MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr &ge_tensor) {
374   MS_EXCEPTION_IF_NULL(ge_tensor);
375   GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape();
376   vector<int64_t> me_dims = ConvertGeShape(ge_shape);
377 
378   TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType());
379   if (type_id == MeDataType::kTypeUnknown) {
380     MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: "
381                   << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
382     return nullptr;
383   }
384   return GenerateMeTensor(ge_tensor, me_dims, type_id);
385 }
386 
387 // if request_dims is empty, use ge tensor's shape,otherwise convert to request shape
ConvertGeTensor(const GeTensorPtr ge_tensor,const ShapeVector & request_dims)388 MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr ge_tensor, const ShapeVector &request_dims) {
389   MS_EXCEPTION_IF_NULL(ge_tensor);
390   GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape();
391   vector<int64_t> me_dims = ConvertGeShape(ge_shape, request_dims);
392   MS_LOG(INFO) << "GE tensor type is " << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
393   // Create a tensor with wanted data type and shape
394   TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType());
395   if (type_id == MeDataType::kTypeUnknown) {
396     MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: "
397                   << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
398     return nullptr;
399   }
400   return GenerateMeTensor(ge_tensor, me_dims, type_id);
401 }
402 
PrintGeTensor(const GeTensorPtr ge_tensor)403 std::string TransformUtil::PrintGeTensor(const GeTensorPtr ge_tensor) {
404   std::string ret;
405   if (ge_tensor == nullptr) {
406     MS_LOG(ERROR) << "Input ge tensor is nullptr";
407     return ret;
408   }
409 
410   MS_LOG(INFO) << "Ge Tensor data type is : " << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
411   switch (ge_tensor->GetTensorDesc().GetDataType()) {
412     case GeDataType::DT_UINT32:
413       ret = PrintVector(MakeVector<uint32_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
414       break;
415     case GeDataType::DT_FLOAT:
416       ret = PrintVector(MakeVector<float_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
417       break;
418     case GeDataType::DT_INT32:
419       ret = PrintVector(MakeVector<int32_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
420       break;
421     case GeDataType::DT_DOUBLE:
422       ret = PrintVector(MakeVector<double_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
423       break;
424     case GeDataType::DT_INT64:
425       ret = PrintVector(MakeVector<int64_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
426       break;
427     case GeDataType::DT_UINT64:
428       ret = PrintVector(MakeVector<uint64_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
429       break;
430     case GeDataType::DT_INT16:
431       ret = PrintVector(MakeVector<int16_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
432       break;
433     case GeDataType::DT_UINT16:
434       ret = PrintVector(MakeVector<uint16_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
435       break;
436     case GeDataType::DT_DUAL_SUB_INT8:
437     case GeDataType::DT_INT8:
438       ret = PrintVector(MakeVector<int8_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
439       break;
440     case GeDataType::DT_UINT8:
441     case GeDataType::DT_DUAL_SUB_UINT8:
442       ret = PrintVector(MakeVector<uint8_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
443       break;
444     case GeDataType::DT_FLOAT16:
445     case GeDataType::DT_BOOL:
446     case GeDataType::DT_UNDEFINED:
447     case GeDataType::DT_DUAL:
448     default:
449       MS_LOG(ERROR) << "Unsupported to print type:" << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType())
450                     << " ge tensor";
451       break;
452   }
453   return ret;
454 }
455 }  // namespace transform
456 }  // namespace mindspore
457