• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "transform/graph_ir/transform_util.h"
18 #include <utility>
19 #include <map>
20 #include <algorithm>
21 #include <complex>
22 
23 #include "include/common/utils/convert_utils.h"
24 #include "include/common/utils/utils.h"
25 #include "utils/shape_utils.h"
26 #include "transform/graph_ir/op_adapter_util.h"
27 
28 #ifndef ENABLE_LITE_ACL
29 #include "include/common/utils/python_adapter.h"
30 #endif
31 
32 namespace mindspore {
33 namespace transform {
using std::make_shared;
using std::shared_ptr;
using std::string;
using std::vector;

// Value compared against GetDataTypeSize() results to detect a data type whose element size is unknown.
const size_t kErrorSize{0};
// Positional indices used when inspecting individual dimensions of a shape vector.
const size_t kIdx0{0};
const size_t kIdx1{1};
const size_t kIdx2{2};
const size_t kIdx3{3};
44 
45 namespace {
46 class MsTensorRel {
47  public:
MsTensorRel(const MeTensorPtr & tensor)48   explicit MsTensorRel(const MeTensorPtr &tensor) : tensor_(tensor) {}
49   ~MsTensorRel() = default;
Rel() const50   void Rel() const { tensor_ = nullptr; }
51 
52  private:
53   mutable MeTensorPtr tensor_;
54 };
55 }  // namespace
56 
57 class TensorRefData : public tensor::TensorData {
58  public:
TensorRefData(void * data,ssize_t data_size,ssize_t itemsize,ssize_t ndim)59   TensorRefData(void *data, ssize_t data_size, ssize_t itemsize, ssize_t ndim)
60       : data_(data), data_size_(data_size), itemsize_(itemsize), ndim_(ndim) {}
61 
62   ~TensorRefData() override = default;
63 
64   // Total number of elements.
size() const65   ssize_t size() const override { return data_size_; }
66 
67   // Byte size of a single element.
itemsize() const68   ssize_t itemsize() const override { return itemsize_; }
69 
70   // Total number of bytes.
nbytes() const71   ssize_t nbytes() const override { return size() * itemsize(); }
72 
73   // Number of dimensions.
ndim() const74   ssize_t ndim() const override { return ndim_; }
75 
data()76   void *data() override { return data_; }
const_data() const77   const void *const_data() const override { return data_; }
78 
is_sub_data() const79   bool is_sub_data() const override { return false; }
has_sub_data() const80   bool has_sub_data() const override { return false; }
81 
ToString(TypeId type,const ShapeVector & shape,bool use_comma) const82   std::string ToString(TypeId type, const ShapeVector &shape, bool use_comma) const override { return ""; }
83 
84  protected:
85   void *data_ = nullptr;
86   ssize_t data_size_ = 0;
87   ssize_t itemsize_ = 0;
88   ssize_t ndim_ = 0;
89 };
90 
ConvertIntToList(int64_t data,int size)91 vector<int64_t> TransformUtil::ConvertIntToList(int64_t data, int size) {
92   vector<int64_t> list{};
93   if (size <= 0) {
94     MS_LOG(WARNING) << "size <= 0";
95     return list;
96   }
97   for (int i = 0; i < size; ++i) {
98     list.emplace_back(data);
99   }
100   return list;
101 }
102 
103 static std::map<MeDataType, GeDataType> datatype_trans_map = {
104   {MeDataType::kNumberTypeFloat16, GeDataType::DT_FLOAT16},
105   {MeDataType::kNumberTypeFloat32, GeDataType::DT_FLOAT},
106   {MeDataType::kNumberTypeFloat64, GeDataType::DT_DOUBLE},
107   {MeDataType::kNumberTypeBFloat16, GeDataType::DT_BF16},
108   {MeDataType::kNumberTypeInt4, GeDataType::DT_INT4},
109   {MeDataType::kNumberTypeInt8, GeDataType::DT_INT8},
110   {MeDataType::kNumberTypeInt16, GeDataType::DT_INT16},
111   {MeDataType::kNumberTypeInt32, GeDataType::DT_INT32},
112   {MeDataType::kNumberTypeInt64, GeDataType::DT_INT64},
113   {MeDataType::kNumberTypeUInt8, GeDataType::DT_UINT8},
114   {MeDataType::kNumberTypeUInt16, GeDataType::DT_UINT16},
115   {MeDataType::kNumberTypeUInt32, GeDataType::DT_UINT32},
116   {MeDataType::kNumberTypeUInt64, GeDataType::DT_UINT64},
117   {MeDataType::kNumberTypeBool, GeDataType::DT_BOOL},
118   {MeDataType::kObjectTypeString, GeDataType::DT_STRING},
119   {MeDataType::kNumberTypeFloat, GeDataType::DT_FLOAT},
120   {MeDataType::kNumberTypeComplex64, GeDataType::DT_COMPLEX64},
121   {MeDataType::kNumberTypeComplex128, GeDataType::DT_COMPLEX128}};
122 
ConvertDataType(const MeDataType & type)123 GeDataType TransformUtil::ConvertDataType(const MeDataType &type) {
124   MS_LOG(DEBUG) << "Convert me data type: " << TypeIdLabel(type) << " to ge data type";
125   if (datatype_trans_map.find(type) != datatype_trans_map.end()) {
126     return datatype_trans_map[type];
127   } else {
128     return GeDataType::DT_UNDEFINED;
129   }
130 }
131 
ConvertFormat(const string & format,const size_t shape_size)132 GeFormat TransformUtil::ConvertFormat(const string &format, const size_t shape_size) {
133   static constexpr size_t k4dSize = 4;
134   static const std::map<std::string, GeFormat> format_map = {
135     {kOpFormat_DEFAULT, GeFormat::FORMAT_NCHW},
136     {kOpFormat_NC1KHKWHWC0, GeFormat::FORMAT_NC1KHKWHWC0},
137     {kOpFormat_ND, GeFormat::FORMAT_ND},
138     {kOpFormat_NCHW, GeFormat::FORMAT_NCHW},
139     {kOpFormat_NHWC, GeFormat::FORMAT_NHWC},
140     {kOpFormat_HWCN, GeFormat::FORMAT_HWCN},
141     {kOpFormat_NC1HWC0, GeFormat::FORMAT_NC1HWC0},
142     {kOpFormat_FRAC_Z, GeFormat::FORMAT_FRACTAL_Z},
143     {kOpFormat_FRAC_NZ, GeFormat::FORMAT_FRACTAL_NZ},
144     {kOpFormat_C1HWNCoC0, GeFormat::FORMAT_C1HWNCoC0},
145     {kOpFormat_NC1HWC0_C04, GeFormat::FORMAT_NC1HWC0_C04},
146     {kOpFormat_FRACTAL_Z_C04, GeFormat::FORMAT_FRACTAL_Z_C04},
147     {kOpFormat_NDHWC, GeFormat::FORMAT_NDHWC},
148     {kOpFormat_NCDHW, GeFormat::FORMAT_NCDHW},
149     {kOpFormat_DHWNC, GeFormat::FORMAT_DHWNC},
150     {kOpFormat_DHWCN, GeFormat::FORMAT_DHWCN},
151     {kOpFormat_NDC1HWC0, GeFormat::FORMAT_NDC1HWC0},
152     {kOpFormat_FRACTAL_Z_3D, GeFormat::FORMAT_FRACTAL_Z_3D},
153     {kOpFormat_FRACTAL_ZN_LSTM, GeFormat::FORMAT_FRACTAL_ZN_LSTM},
154     {kOpFormat_ND_RNN_BIAS, GeFormat::FORMAT_ND_RNN_BIAS},
155     {kOpFormat_FRACTAL_ZN_RNN, GeFormat::FORMAT_FRACTAL_ZN_RNN}};
156   if (format == kOpFormat_DEFAULT) {
157     return shape_size == k4dSize ? GeFormat::FORMAT_NCHW : GeFormat::FORMAT_ND;
158   }
159   auto iter = format_map.find(format);
160   if (iter == format_map.end()) {
161     MS_LOG(ERROR) << "Illegal tensor data format: (" << format << "). Use ND format instead.";
162     return GeFormat::FORMAT_ND;
163   }
164   return iter->second;
165 }
166 
GetGeTensorDesc(const ShapeVector & ori_shape,const MeDataType & me_type,const std::string & ori_format,const ShapeVector & dev_shape,const std::string & dev_format)167 std::shared_ptr<GeTensorDesc> TransformUtil::GetGeTensorDesc(const ShapeVector &ori_shape, const MeDataType &me_type,
168                                                              const std::string &ori_format,
169                                                              const ShapeVector &dev_shape,
170                                                              const std::string &dev_format) {
171   // convert me shape to ge shape
172   GeShape ori_ge_shape(ori_shape);
173   if (ori_ge_shape.GetDimNum() == 0) {
174     MS_LOG(DEBUG) << "The dims size of Ge tensor is zero";
175   }
176   // convert me format to ge format
177   GeFormat ori_ge_format = ConvertFormat(ori_format, ori_shape.size());
178   if (ori_ge_format == GeFormat::FORMAT_ND) {
179     MS_LOG(DEBUG) << "Set ND data format";
180   }
181   // convert me datatype to ge datatype
182   GeDataType data_type = ConvertDataType(me_type);
183   if (data_type == GeDataType::DT_UNDEFINED) {
184     MS_LOG(WARNING) << "undefined data type :" << me_type;
185     return nullptr;
186   }
187   auto desc = std::make_shared<GeTensorDesc>();
188   if (desc == nullptr) {
189     MS_LOG(ERROR) << "Create GeTensorDesc failed!";
190     return nullptr;
191   }
192   // set ori shape and format.
193   // note: if ori_shape and ori_format have been set. the set_shape and set_format will run as device info, otherwise
194   // the set_shape and set_format will run as host info.
195   if (!std::any_of(ori_shape.cbegin(), ori_shape.cend(), [](const auto &dim) { return dim < 0; })) {
196     desc->SetOriginShape(ori_ge_shape);
197     desc->SetOriginFormat(ori_ge_format);
198   }
199   desc->SetDataType(data_type);
200 
201   // set device shape and format, if value is empty, use ori shape and format replace.
202   auto dev_ge_shape = dev_shape.empty() ? ori_ge_shape : GeShape(dev_shape);
203   GeFormat dev_ge_format = dev_format.empty() ? ori_ge_format : ConvertFormat(dev_format, dev_ge_shape.GetDimNum());
204   if (me_type == MeDataType::kNumberTypeInt4) {
205     int64_t last_dim = dev_ge_shape.GetDimNum() - 1;
206     dev_ge_shape.SetDim(last_dim, dev_ge_shape.GetDim(last_dim) * 2);
207   }
208   desc->SetShape(dev_ge_shape);
209   desc->SetFormat(dev_ge_format);
210 
211   MS_LOG(DEBUG) << "SetRealDimCnt is :" << ori_shape.size();
212   desc->SetRealDimCnt(SizeToInt(ori_shape.size()));
213   return desc;
214 }
215 
216 // if failed, return empty vector.
ConvertInputTensors(const std::vector<MeTensorPtr> & me_tensors,const std::string & format)217 std::vector<GeTensorPtr> TransformUtil::ConvertInputTensors(const std::vector<MeTensorPtr> &me_tensors,
218                                                             const std::string &format) {
219   std::vector<GeTensorPtr> ge_tensors;
220 
221   for (size_t index = 0; index < me_tensors.size(); index++) {
222     MS_EXCEPTION_IF_NULL(me_tensors[index]);
223     MS_LOG(INFO) << "me_tensor " << index << " 's data size is: " << me_tensors[index]->DataSize();
224     auto shape = me_tensors[index]->shape();
225     std::string shape_str;
226     for (size_t i = 0; i < shape.size(); i++) {
227       shape_str += std::to_string(shape[i]);
228       shape_str += " ";
229     }
230     MS_LOG(INFO) << "me_tensor " << index << " 's shape is: { " << shape_str << "}";
231     MS_LOG(INFO) << "me_tensor " << index << " 's type is: " << me_tensors[index]->data_type();
232 
233     auto ge_tensor_ptr = TransformUtil::ConvertTensor(me_tensors[index], format);
234     if (ge_tensor_ptr != nullptr) {
235       (void)ge_tensors.emplace_back(ge_tensor_ptr);
236     } else {
237       MS_LOG(ERROR) << "Convert me_tensor " << index << " to Ge Tensor failed!";
238       ge_tensors.clear();
239       return ge_tensors;
240     }
241   }
242   return ge_tensors;
243 }
244 
245 #ifndef ENABLE_LITE_ACL
ConvertStringTensor(const MeTensorPtr & tensor,const std::string & format)246 GeTensorPtr ConvertStringTensor(const MeTensorPtr &tensor, const std::string &format) {
247   auto desc = TransformUtil::GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format);
248   if (desc == nullptr) {
249     MS_LOG(ERROR) << "Failed to get Tensor Desc";
250     return nullptr;
251   }
252   GeTensorPtr tensor_ptr = nullptr;
253   auto data_buff_size = tensor->data().nbytes();
254   py::gil_scoped_acquire gil;
255   auto py_array = python_adapter::PyAdapterCallback::TensorToNumpy(*tensor);
256   auto buf = py_array.request();
257   auto data_ptr = static_cast<char *>(tensor->data().data());
258   size_t single_char_offset = 4;
259 
260   if (buf.format.back() == 'w') {
261     auto max_length = buf.format.substr(0, buf.format.length() - 1);
262     int64_t max_length_long = 0;
263     try {
264       max_length_long = std::stol(max_length);
265     } catch (const std::exception &e) {
266       MS_LOG(EXCEPTION) << "Invalid argument:" << e.what() << " when parse " << max_length;
267     }
268     auto string_max_length = LongToSize(max_length_long);
269     if (string_max_length == 0) {
270       MS_LOG(ERROR) << "Failed to get Tensor Desc. Please check string length";
271       return nullptr;
272     }
273     size_t elements_num = (data_buff_size / single_char_offset) / string_max_length;
274     std::vector<std::string> string_vector;
275     char *string_element = new char[string_max_length];
276     size_t string_length = 0;
277     for (size_t i = 0; i < elements_num; i++) {
278       (void)std::fill_n(string_element, string_max_length, '\0');
279       for (size_t j = 0; j < string_max_length; j++) {
280         char char_element = data_ptr[i * string_max_length * single_char_offset + single_char_offset * j];
281         if (static_cast<int>(char_element) == 0) {
282           break;
283         } else {
284           string_element[j] = char_element;
285           string_length += 1;
286         }
287       }
288       std::string string_to_add(string_element, string_length);
289       (void)string_vector.emplace_back(string_to_add);
290     }
291     delete[] string_element;
292     string_element = nullptr;
293     tensor_ptr = make_shared<GeTensor>(*desc);
294     (void)tensor_ptr->SetData(string_vector);
295   } else {
296     int64_t length_long = 0;
297     try {
298       length_long = std::stol(buf.format.substr(0, buf.format.length() - 1));
299     } catch (const std::exception &e) {
300       MS_LOG(EXCEPTION) << "Invalid argument:" << e.what() << " when parse "
301                         << buf.format.substr(0, buf.format.length() - 1);
302     }
303     auto string_length = LongToSize(length_long);
304     if (string_length == 0) {
305       MS_LOG(ERROR) << "Failed to get Tensor Desc. Please check string length";
306       return nullptr;
307     }
308     char *string_element = new char[string_length];
309     for (size_t i = 0; i < string_length; i++) {
310       string_element[i] = data_ptr[i];
311     }
312     std::string string_to_add(string_element, string_length);
313     tensor_ptr = make_shared<GeTensor>(*desc);
314     (void)tensor_ptr->SetData(string_to_add);
315     delete[] string_element;
316     string_element = nullptr;
317   }
318   return tensor_ptr;
319 }
320 #endif
321 
ConvertTensor(const MeTensorPtr & tensor,const std::string & format,bool copy)322 GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr &tensor, const std::string &format, bool copy) {
323   // get tensor data type size
324   MS_EXCEPTION_IF_NULL(tensor);
325   auto me_data_type = tensor->data_type();
326 #ifndef ENABLE_LITE_ACL
327   if (me_data_type == mindspore::kObjectTypeString) {
328     return ConvertStringTensor(tensor, format);
329   }
330 #endif
331   size_t type_size = GetDataTypeSize(me_data_type);
332   if (type_size == kErrorSize) {
333     MS_LOG(ERROR) << "The Me Tensor data type size is wrong, type size is: " << type_size;
334     return nullptr;
335   }
336 
337   // get tensor buff size
338   size_t data_buff_size = tensor->Size();
339   if (data_buff_size == 0) {
340     MS_LOG(INFO) << "The Me Tensor data buff size is 0.";
341   }
342   // create ge tensor
343   auto desc = GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format);
344   if (desc == nullptr) {
345     MS_LOG(ERROR) << "Failed to get Tensor Desc";
346     return nullptr;
347   }
348   GeTensorPtr tensor_ptr = make_shared<GeTensor>(*desc);
349   if (tensor_ptr == nullptr) {
350     MS_LOG(ERROR) << "Failed to convert Me Tensor to Ge Tensor!";
351     return nullptr;
352   }
353   if (copy) {
354     auto ret = tensor_ptr->SetData(static_cast<uint8_t *>(tensor->data_c()), data_buff_size);
355     if (ret != ge::GRAPH_SUCCESS) {
356       MS_LOG(ERROR) << "Failed to call ge::Tensor SetData(const uint8_t*, size), data size " << data_buff_size;
357       return nullptr;
358     }
359   } else {
360     MsTensorRel rel(tensor);
361     auto ret = tensor_ptr->SetData(static_cast<uint8_t *>(tensor->data_c()), data_buff_size,
362                                    [rel](uint8_t *) -> void { rel.Rel(); });
363     if (ret != ge::GRAPH_SUCCESS) {
364       MS_LOG(ERROR) << "Failed to call ge::Tensor SetData(uint8_t*, size, DeleteFunc), data size " << data_buff_size;
365       return nullptr;
366     }
367   }
368   MS_LOG(DEBUG) << "Convert Me Tensor to Ge Tensor success!";
369   return tensor_ptr;
370 }
371 
// Converts `val` to a GE tensor via the adapter's generic ConvertAnyUtil(..., AnyTraits<ValueAny>()) overload and
// returns it wrapped in a newly allocated GeTensor. `val` is presumably a scalar (per the function name);
// that is not checked here.
GeTensorPtr TransformUtil::ConvertScalar(const ValuePtr &val) {
  auto ge_tensor = ConvertAnyUtil(val, AnyTraits<ValueAny>());
  return make_shared<GeTensor>(ge_tensor);
}
376 
ConvertGeTensors(const std::vector<GeTensorPtr> & ge_tensors,const std::vector<ShapeVector> & request_dims)377 std::vector<MeTensorPtr> TransformUtil::ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors,
378                                                          const std::vector<ShapeVector> &request_dims) {
379   std::vector<MeTensorPtr> outputs;
380 
381   for (size_t index = 0; index < ge_tensors.size(); index++) {
382     MeTensorPtr me_tensor_ptr = nullptr;
383     if (index < request_dims.size()) {
384       me_tensor_ptr = ConvertGeTensor(ge_tensors[index], request_dims[index]);
385     } else {
386       ShapeVector empty_shape;
387       me_tensor_ptr = ConvertGeTensor(ge_tensors[index], empty_shape);
388     }
389 
390     if (me_tensor_ptr != nullptr) {
391       (void)outputs.emplace_back(me_tensor_ptr);
392     } else {
393       MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!";
394       return outputs;
395     }
396   }
397   return outputs;
398 }
399 
ConvertGeTensors(const std::vector<GeTensorPtr> & ge_tensors)400 std::vector<MeTensorPtr> TransformUtil::ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors) {
401   std::vector<MeTensorPtr> outputs;
402 
403   for (size_t index = 0; index < ge_tensors.size(); index++) {
404     MeTensorPtr me_tensor_ptr = ConvertGeTensor(ge_tensors[index]);
405     if (me_tensor_ptr != nullptr) {
406       (void)outputs.emplace_back(me_tensor_ptr);
407     } else {
408       MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!";
409       return outputs;
410     }
411   }
412   return outputs;
413 }
414 
ConvertGeDataType(const GeDataType & type)415 MeDataType TransformUtil::ConvertGeDataType(const GeDataType &type) {
416   switch (type) {
417     case GeDataType::DT_FLOAT16:
418       return MeDataType::kNumberTypeFloat16;
419     case GeDataType::DT_BF16:
420       return MeDataType::kNumberTypeBFloat16;
421     case GeDataType::DT_FLOAT:
422       return MeDataType::kNumberTypeFloat32;
423     case GeDataType::DT_DOUBLE:
424       return MeDataType::kNumberTypeFloat64;
425     case GeDataType::DT_INT64:
426       return MeDataType::kNumberTypeInt64;
427     case GeDataType::DT_INT32:
428       return MeDataType::kNumberTypeInt32;
429     case GeDataType::DT_INT16:
430       return MeDataType::kNumberTypeInt16;
431     case GeDataType::DT_INT8:
432       return MeDataType::kNumberTypeInt8;
433     case GeDataType::DT_BOOL:
434       return MeDataType::kNumberTypeBool;
435     case GeDataType::DT_UINT8:
436       return MeDataType::kNumberTypeUInt8;
437     case GeDataType::DT_UINT16:
438       return MeDataType::kNumberTypeUInt16;
439     case GeDataType::DT_UINT32:
440       return MeDataType::kNumberTypeUInt32;
441     case GeDataType::DT_UINT64:
442       return MeDataType::kNumberTypeUInt64;
443     case GeDataType::DT_UNDEFINED:
444     case GeDataType::DT_DUAL_SUB_UINT8:
445     case GeDataType::DT_DUAL_SUB_INT8:
446     case GeDataType::DT_DUAL:
447       return MeDataType::kTypeUnknown;
448     default:
449       return MeDataType::kTypeUnknown;
450   }
451 }
452 
453 namespace {
IsGeShapeCompatible(const GeShape & ge_shape,const ShapeVector & request_dims)454 bool IsGeShapeCompatible(const GeShape &ge_shape, const ShapeVector &request_dims) {
455   MS_LOG(INFO) << "GeTensor's shape is " << TransformUtil::PrintVector(ge_shape.GetDims());
456   MS_LOG(INFO) << "Me request shape is " << TransformUtil::PrintVector(request_dims);
457 
458   const int GE_DIMS = 4;
459   std::vector<int64_t> ge_dims = ge_shape.GetDims();
460   if (request_dims.size() > ge_dims.size()) {
461     MS_LOG(ERROR) << "Request shape's dims count greater than ge shape's";
462     return false;
463   }
464 
465   // convert NHWC to NCHW
466   if ((request_dims.size() == 1) && (ge_dims.size() == GE_DIMS) && (request_dims[kIdx0] == ge_dims[kIdx1]) &&
467       (ge_dims[kIdx0] == 1) && (ge_dims[kIdx2] == 1) && (ge_dims[kIdx3] == 1)) {
468     MS_LOG(INFO) << "Ge tensor shape and request shape is compatible";
469     return true;
470   }
471 
472   std::string::size_type i = 0;
473   for (; i < request_dims.size(); i++) {
474     if (ge_dims[i] != request_dims[i]) {
475       MS_LOG(ERROR) << "Request shape's dims value not equal to ge shape's";
476       return false;
477     }
478   }
479 
480   for (; i < ge_dims.size(); i++) {
481     if (ge_dims[i] != 1) {
482       MS_LOG(ERROR) << "GeShape's extend dims is not equal to 1";
483       return false;
484     }
485   }
486   MS_LOG(INFO) << "Ge tensor shape and request shape is compatible";
487   return true;
488 }
489 }  // namespace
490 
ConvertMeShape(const ShapeVector & me_dims)491 GeShape TransformUtil::ConvertMeShape(const ShapeVector &me_dims) {
492   std::vector<int64_t> ge_dims;
493   (void)std::copy(me_dims.begin(), me_dims.end(), std::back_inserter(ge_dims));
494   return GeShape(ge_dims);
495 }
496 
ConvertGeShape(const GeShape & ge_shape)497 ShapeVector TransformUtil::ConvertGeShape(const GeShape &ge_shape) {
498   ShapeVector me_dims;
499   std::vector<int64_t> ge_dims = ge_shape.GetDims();
500   (void)std::copy(ge_dims.begin(), ge_dims.end(), std::back_inserter(me_dims));
501   return me_dims;
502 }
503 
// Converts a GE shape into an ME shape, honoring `request_dims` when possible:
//   * scalar GE shape              -> empty ME shape;
//   * empty `request_dims`         -> the GE shape unchanged (no request to honor);
//   * compatible `request_dims`    -> `request_dims` (see IsGeShapeCompatible);
//   * incompatible `request_dims`  -> the GE shape, with an ERROR log.
// The empty-request case is checked explicitly. Otherwise IsGeShapeCompatible would log spurious ERRORs for
// ordinary shapes and would collapse all-ones shapes such as [1, 1] into a scalar.
ShapeVector TransformUtil::ConvertGeShape(const GeShape &ge_shape, const ShapeVector &request_dims) {
  if (ge_shape.GetDimNum() == 0) {
    MS_LOG(DEBUG) << "GeTensor's shape is scalar";
    return ShapeVector{};
  }

  if (request_dims.empty()) {
    return ConvertGeShape(ge_shape);
  }

  if (IsGeShapeCompatible(ge_shape, request_dims)) {
    return request_dims;
  }
  MS_LOG(ERROR) << "GeShape and Me request shape are incompatible, return GeShape";
  return ConvertGeShape(ge_shape);
}
519 
GenerateMeTensor(const GeTensorPtr & ge_tensor,const ShapeVector & me_dims,const TypeId & me_type,bool ref_mem)520 MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr &ge_tensor, const ShapeVector &me_dims,
521                                             const TypeId &me_type, bool ref_mem) {
522   MS_EXCEPTION_IF_NULL(ge_tensor);
523   MS_EXCEPTION_IF_NULL(ge_tensor->GetData());
524   if (ge_tensor->GetSize() == 0) {
525     MS_LOG(ERROR) << "GE tensor data size is zero!";
526     return nullptr;
527   }
528 
529   if (ref_mem) {
530     void *data = reinterpret_cast<void *>(const_cast<uint8_t *>(ge_tensor->GetData()));
531     ssize_t data_size = static_cast<ssize_t>(SizeOf(me_dims));
532     ssize_t itemsize = MeTensor(me_type, ShapeVector()).data().itemsize();
533     ssize_t ndim = static_cast<ssize_t>(me_dims.size());
534     auto ref_data = std::make_shared<TensorRefData>(data, data_size, itemsize, ndim);
535     return make_shared<MeTensor>(me_type, me_dims, ref_data);
536   } else {
537     MeTensor me_tensor(me_type, me_dims);
538 
539     // Get the writable data pointer of the tensor and cast it to its data type.
540     auto me_data_ptr = me_tensor.data_c();
541     size_t me_data_size = static_cast<size_t>(me_tensor.data().nbytes());
542     MS_EXCEPTION_IF_NULL(me_data_ptr);
543     size_t length = ge_tensor->GetSize();
544     if (me_data_size < length) {
545       MS_LOG(ERROR) << "ME tensor data size[" << me_data_size << " bytes] is less than GE tensor [" << length
546                     << " bytes]";
547       return nullptr;
548     }
549 
550     if (length < SECUREC_MEM_MAX_LEN) {
551       int ret_code = memcpy_s(me_data_ptr, length, ge_tensor->GetData(), length);
552       if (ret_code != EOK) {
553         MS_LOG(ERROR) << "Memcpy_s from ge_tensor to me_tensor failed.";
554         return nullptr;
555       }
556     } else {
557       (void)memcpy(me_data_ptr, ge_tensor->GetData(), length);
558     }
559 
560     return make_shared<MeTensor>(me_tensor);
561   }
562 }
563 
ConvertGeTensor(const GeTensorPtr & ge_tensor)564 MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr &ge_tensor) {
565   MS_EXCEPTION_IF_NULL(ge_tensor);
566   GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape();
567   vector<int64_t> me_dims = ConvertGeShape(ge_shape);
568 
569   TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType());
570   if (type_id == MeDataType::kTypeUnknown) {
571     MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: "
572                   << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
573     return nullptr;
574   }
575   return GenerateMeTensor(ge_tensor, me_dims, type_id);
576 }
577 
ConvertGeTensor(const GeTensorPtr & ge_tensor,const TypeId & me_type)578 MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr &ge_tensor, const TypeId &me_type) {
579   MS_EXCEPTION_IF_NULL(ge_tensor);
580   GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape();
581   vector<int64_t> me_dims = ConvertGeShape(ge_shape);
582 
583   if (me_type == MeDataType::kTypeUnknown) {
584     MS_LOG(ERROR) << "Unsupported data type: " << static_cast<int>(me_type);
585     return nullptr;
586   }
587   return GenerateMeTensor(ge_tensor, me_dims, me_type);
588 }
589 
590 // if request_dims is empty, use ge tensor's shape,otherwise convert to request shape
ConvertGeTensor(const GeTensorPtr ge_tensor,const ShapeVector & request_dims,bool ref_mem)591 MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr ge_tensor, const ShapeVector &request_dims, bool ref_mem) {
592   MS_EXCEPTION_IF_NULL(ge_tensor);
593   GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape();
594   vector<int64_t> me_dims = ConvertGeShape(ge_shape, request_dims);
595   MS_LOG(INFO) << "GE tensor type is " << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
596   // Create a tensor with wanted data type and shape
597   TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType());
598   if (type_id == MeDataType::kTypeUnknown) {
599     MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: "
600                   << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
601     return nullptr;
602   }
603   return GenerateMeTensor(ge_tensor, me_dims, type_id, ref_mem);
604 }
605 
PrintGeTensor(const GeTensorPtr ge_tensor)606 std::string TransformUtil::PrintGeTensor(const GeTensorPtr ge_tensor) {
607   std::string ret;
608   if (ge_tensor == nullptr) {
609     MS_LOG(ERROR) << "Input ge tensor is nullptr";
610     return ret;
611   }
612 
613   MS_LOG(INFO) << "Ge Tensor data type is : " << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
614   switch (static_cast<int>(ge_tensor->GetTensorDesc().GetDataType())) {
615     case GeDataType::DT_UINT32:
616       ret = PrintVector(MakeVector<uint32_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
617       break;
618     case GeDataType::DT_FLOAT:
619       ret = PrintVector(MakeVector<float_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
620       break;
621     case GeDataType::DT_INT32:
622       ret = PrintVector(MakeVector<int32_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
623       break;
624     case GeDataType::DT_DOUBLE:
625       ret = PrintVector(MakeVector<double_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
626       break;
627     case GeDataType::DT_INT64:
628       ret = PrintVector(MakeVector<int64_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
629       break;
630     case GeDataType::DT_UINT64:
631       ret = PrintVector(MakeVector<uint64_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
632       break;
633     case GeDataType::DT_INT16:
634       ret = PrintVector(MakeVector<int16_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
635       break;
636     case GeDataType::DT_UINT16:
637       ret = PrintVector(MakeVector<uint16_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
638       break;
639     case GeDataType::DT_DUAL_SUB_INT8:
640     case GeDataType::DT_INT8:
641       ret = PrintVector(MakeVector<int8_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
642       break;
643     case GeDataType::DT_UINT8:
644     case GeDataType::DT_DUAL_SUB_UINT8:
645       ret = PrintVector(MakeVector<uint8_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
646       break;
647     case GeDataType::DT_FLOAT16:
648     case GeDataType::DT_BOOL:
649     case GeDataType::DT_UNDEFINED:
650     case GeDataType::DT_DUAL:
651     default:
652       MS_LOG(ERROR) << "Unsupported to print type:" << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType())
653                     << " ge tensor";
654       break;
655   }
656   return ret;
657 }
658 
NormOpName(const std::string & anf_name)659 std::string TransformUtil::NormOpName(const std::string &anf_name) {
660   std::string str = anf_name.substr(anf_name.rfind("/") + 1);
661   std::string ret;
662   for (const auto &c : str) {
663     if (std::isalnum(c) || c == '_' || c == '-') {
664       ret += c;
665     }
666   }
667   return ret;
668 }
669 }  // namespace transform
670 }  // namespace mindspore
671