1 /**
2 * Copyright 2019-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "transform/graph_ir/transform_util.h"
18 #include <utility>
19 #include <map>
20 #include <algorithm>
21 #include <complex>
22
23 #include "include/common/utils/convert_utils.h"
24 #include "include/common/utils/utils.h"
25 #include "utils/shape_utils.h"
26 #include "transform/graph_ir/op_adapter_util.h"
27
28 #ifndef ENABLE_LITE_ACL
29 #include "include/common/utils/python_adapter.h"
30 #endif
31
32 namespace mindspore {
33 namespace transform {
34 using std::make_shared;
35 using std::shared_ptr;
36 using std::string;
37 using std::vector;
38
// Data-type size that ConvertTensor treats as "type size could not be determined".
constexpr size_t kErrorSize = 0;
// Named positions used when indexing into dimension vectors.
constexpr size_t kIdx0 = 0;
constexpr size_t kIdx1 = 1;
constexpr size_t kIdx2 = 2;
constexpr size_t kIdx3 = 3;
44
45 namespace {
46 class MsTensorRel {
47 public:
MsTensorRel(const MeTensorPtr & tensor)48 explicit MsTensorRel(const MeTensorPtr &tensor) : tensor_(tensor) {}
49 ~MsTensorRel() = default;
Rel() const50 void Rel() const { tensor_ = nullptr; }
51
52 private:
53 mutable MeTensorPtr tensor_;
54 };
55 } // namespace
56
57 class TensorRefData : public tensor::TensorData {
58 public:
TensorRefData(void * data,ssize_t data_size,ssize_t itemsize,ssize_t ndim)59 TensorRefData(void *data, ssize_t data_size, ssize_t itemsize, ssize_t ndim)
60 : data_(data), data_size_(data_size), itemsize_(itemsize), ndim_(ndim) {}
61
62 ~TensorRefData() override = default;
63
64 // Total number of elements.
size() const65 ssize_t size() const override { return data_size_; }
66
67 // Byte size of a single element.
itemsize() const68 ssize_t itemsize() const override { return itemsize_; }
69
70 // Total number of bytes.
nbytes() const71 ssize_t nbytes() const override { return size() * itemsize(); }
72
73 // Number of dimensions.
ndim() const74 ssize_t ndim() const override { return ndim_; }
75
data()76 void *data() override { return data_; }
const_data() const77 const void *const_data() const override { return data_; }
78
is_sub_data() const79 bool is_sub_data() const override { return false; }
has_sub_data() const80 bool has_sub_data() const override { return false; }
81
ToString(TypeId type,const ShapeVector & shape,bool use_comma) const82 std::string ToString(TypeId type, const ShapeVector &shape, bool use_comma) const override { return ""; }
83
84 protected:
85 void *data_ = nullptr;
86 ssize_t data_size_ = 0;
87 ssize_t itemsize_ = 0;
88 ssize_t ndim_ = 0;
89 };
90
ConvertIntToList(int64_t data,int size)91 vector<int64_t> TransformUtil::ConvertIntToList(int64_t data, int size) {
92 vector<int64_t> list{};
93 if (size <= 0) {
94 MS_LOG(WARNING) << "size <= 0";
95 return list;
96 }
97 for (int i = 0; i < size; ++i) {
98 list.emplace_back(data);
99 }
100 return list;
101 }
102
103 static std::map<MeDataType, GeDataType> datatype_trans_map = {
104 {MeDataType::kNumberTypeFloat16, GeDataType::DT_FLOAT16},
105 {MeDataType::kNumberTypeFloat32, GeDataType::DT_FLOAT},
106 {MeDataType::kNumberTypeFloat64, GeDataType::DT_DOUBLE},
107 {MeDataType::kNumberTypeBFloat16, GeDataType::DT_BF16},
108 {MeDataType::kNumberTypeInt4, GeDataType::DT_INT4},
109 {MeDataType::kNumberTypeInt8, GeDataType::DT_INT8},
110 {MeDataType::kNumberTypeInt16, GeDataType::DT_INT16},
111 {MeDataType::kNumberTypeInt32, GeDataType::DT_INT32},
112 {MeDataType::kNumberTypeInt64, GeDataType::DT_INT64},
113 {MeDataType::kNumberTypeUInt8, GeDataType::DT_UINT8},
114 {MeDataType::kNumberTypeUInt16, GeDataType::DT_UINT16},
115 {MeDataType::kNumberTypeUInt32, GeDataType::DT_UINT32},
116 {MeDataType::kNumberTypeUInt64, GeDataType::DT_UINT64},
117 {MeDataType::kNumberTypeBool, GeDataType::DT_BOOL},
118 {MeDataType::kObjectTypeString, GeDataType::DT_STRING},
119 {MeDataType::kNumberTypeFloat, GeDataType::DT_FLOAT},
120 {MeDataType::kNumberTypeComplex64, GeDataType::DT_COMPLEX64},
121 {MeDataType::kNumberTypeComplex128, GeDataType::DT_COMPLEX128}};
122
ConvertDataType(const MeDataType & type)123 GeDataType TransformUtil::ConvertDataType(const MeDataType &type) {
124 MS_LOG(DEBUG) << "Convert me data type: " << TypeIdLabel(type) << " to ge data type";
125 if (datatype_trans_map.find(type) != datatype_trans_map.end()) {
126 return datatype_trans_map[type];
127 } else {
128 return GeDataType::DT_UNDEFINED;
129 }
130 }
131
ConvertFormat(const string & format,const size_t shape_size)132 GeFormat TransformUtil::ConvertFormat(const string &format, const size_t shape_size) {
133 static constexpr size_t k4dSize = 4;
134 static const std::map<std::string, GeFormat> format_map = {
135 {kOpFormat_DEFAULT, GeFormat::FORMAT_NCHW},
136 {kOpFormat_NC1KHKWHWC0, GeFormat::FORMAT_NC1KHKWHWC0},
137 {kOpFormat_ND, GeFormat::FORMAT_ND},
138 {kOpFormat_NCHW, GeFormat::FORMAT_NCHW},
139 {kOpFormat_NHWC, GeFormat::FORMAT_NHWC},
140 {kOpFormat_HWCN, GeFormat::FORMAT_HWCN},
141 {kOpFormat_NC1HWC0, GeFormat::FORMAT_NC1HWC0},
142 {kOpFormat_FRAC_Z, GeFormat::FORMAT_FRACTAL_Z},
143 {kOpFormat_FRAC_NZ, GeFormat::FORMAT_FRACTAL_NZ},
144 {kOpFormat_C1HWNCoC0, GeFormat::FORMAT_C1HWNCoC0},
145 {kOpFormat_NC1HWC0_C04, GeFormat::FORMAT_NC1HWC0_C04},
146 {kOpFormat_FRACTAL_Z_C04, GeFormat::FORMAT_FRACTAL_Z_C04},
147 {kOpFormat_NDHWC, GeFormat::FORMAT_NDHWC},
148 {kOpFormat_NCDHW, GeFormat::FORMAT_NCDHW},
149 {kOpFormat_DHWNC, GeFormat::FORMAT_DHWNC},
150 {kOpFormat_DHWCN, GeFormat::FORMAT_DHWCN},
151 {kOpFormat_NDC1HWC0, GeFormat::FORMAT_NDC1HWC0},
152 {kOpFormat_FRACTAL_Z_3D, GeFormat::FORMAT_FRACTAL_Z_3D},
153 {kOpFormat_FRACTAL_ZN_LSTM, GeFormat::FORMAT_FRACTAL_ZN_LSTM},
154 {kOpFormat_ND_RNN_BIAS, GeFormat::FORMAT_ND_RNN_BIAS},
155 {kOpFormat_FRACTAL_ZN_RNN, GeFormat::FORMAT_FRACTAL_ZN_RNN}};
156 if (format == kOpFormat_DEFAULT) {
157 return shape_size == k4dSize ? GeFormat::FORMAT_NCHW : GeFormat::FORMAT_ND;
158 }
159 auto iter = format_map.find(format);
160 if (iter == format_map.end()) {
161 MS_LOG(ERROR) << "Illegal tensor data format: (" << format << "). Use ND format instead.";
162 return GeFormat::FORMAT_ND;
163 }
164 return iter->second;
165 }
166
GetGeTensorDesc(const ShapeVector & ori_shape,const MeDataType & me_type,const std::string & ori_format,const ShapeVector & dev_shape,const std::string & dev_format)167 std::shared_ptr<GeTensorDesc> TransformUtil::GetGeTensorDesc(const ShapeVector &ori_shape, const MeDataType &me_type,
168 const std::string &ori_format,
169 const ShapeVector &dev_shape,
170 const std::string &dev_format) {
171 // convert me shape to ge shape
172 GeShape ori_ge_shape(ori_shape);
173 if (ori_ge_shape.GetDimNum() == 0) {
174 MS_LOG(DEBUG) << "The dims size of Ge tensor is zero";
175 }
176 // convert me format to ge format
177 GeFormat ori_ge_format = ConvertFormat(ori_format, ori_shape.size());
178 if (ori_ge_format == GeFormat::FORMAT_ND) {
179 MS_LOG(DEBUG) << "Set ND data format";
180 }
181 // convert me datatype to ge datatype
182 GeDataType data_type = ConvertDataType(me_type);
183 if (data_type == GeDataType::DT_UNDEFINED) {
184 MS_LOG(WARNING) << "undefined data type :" << me_type;
185 return nullptr;
186 }
187 auto desc = std::make_shared<GeTensorDesc>();
188 if (desc == nullptr) {
189 MS_LOG(ERROR) << "Create GeTensorDesc failed!";
190 return nullptr;
191 }
192 // set ori shape and format.
193 // note: if ori_shape and ori_format have been set. the set_shape and set_format will run as device info, otherwise
194 // the set_shape and set_format will run as host info.
195 if (!std::any_of(ori_shape.cbegin(), ori_shape.cend(), [](const auto &dim) { return dim < 0; })) {
196 desc->SetOriginShape(ori_ge_shape);
197 desc->SetOriginFormat(ori_ge_format);
198 }
199 desc->SetDataType(data_type);
200
201 // set device shape and format, if value is empty, use ori shape and format replace.
202 auto dev_ge_shape = dev_shape.empty() ? ori_ge_shape : GeShape(dev_shape);
203 GeFormat dev_ge_format = dev_format.empty() ? ori_ge_format : ConvertFormat(dev_format, dev_ge_shape.GetDimNum());
204 if (me_type == MeDataType::kNumberTypeInt4) {
205 int64_t last_dim = dev_ge_shape.GetDimNum() - 1;
206 dev_ge_shape.SetDim(last_dim, dev_ge_shape.GetDim(last_dim) * 2);
207 }
208 desc->SetShape(dev_ge_shape);
209 desc->SetFormat(dev_ge_format);
210
211 MS_LOG(DEBUG) << "SetRealDimCnt is :" << ori_shape.size();
212 desc->SetRealDimCnt(SizeToInt(ori_shape.size()));
213 return desc;
214 }
215
216 // if failed, return empty vector.
ConvertInputTensors(const std::vector<MeTensorPtr> & me_tensors,const std::string & format)217 std::vector<GeTensorPtr> TransformUtil::ConvertInputTensors(const std::vector<MeTensorPtr> &me_tensors,
218 const std::string &format) {
219 std::vector<GeTensorPtr> ge_tensors;
220
221 for (size_t index = 0; index < me_tensors.size(); index++) {
222 MS_EXCEPTION_IF_NULL(me_tensors[index]);
223 MS_LOG(INFO) << "me_tensor " << index << " 's data size is: " << me_tensors[index]->DataSize();
224 auto shape = me_tensors[index]->shape();
225 std::string shape_str;
226 for (size_t i = 0; i < shape.size(); i++) {
227 shape_str += std::to_string(shape[i]);
228 shape_str += " ";
229 }
230 MS_LOG(INFO) << "me_tensor " << index << " 's shape is: { " << shape_str << "}";
231 MS_LOG(INFO) << "me_tensor " << index << " 's type is: " << me_tensors[index]->data_type();
232
233 auto ge_tensor_ptr = TransformUtil::ConvertTensor(me_tensors[index], format);
234 if (ge_tensor_ptr != nullptr) {
235 (void)ge_tensors.emplace_back(ge_tensor_ptr);
236 } else {
237 MS_LOG(ERROR) << "Convert me_tensor " << index << " to Ge Tensor failed!";
238 ge_tensors.clear();
239 return ge_tensors;
240 }
241 }
242 return ge_tensors;
243 }
244
245 #ifndef ENABLE_LITE_ACL
ConvertStringTensor(const MeTensorPtr & tensor,const std::string & format)246 GeTensorPtr ConvertStringTensor(const MeTensorPtr &tensor, const std::string &format) {
247 auto desc = TransformUtil::GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format);
248 if (desc == nullptr) {
249 MS_LOG(ERROR) << "Failed to get Tensor Desc";
250 return nullptr;
251 }
252 GeTensorPtr tensor_ptr = nullptr;
253 auto data_buff_size = tensor->data().nbytes();
254 py::gil_scoped_acquire gil;
255 auto py_array = python_adapter::PyAdapterCallback::TensorToNumpy(*tensor);
256 auto buf = py_array.request();
257 auto data_ptr = static_cast<char *>(tensor->data().data());
258 size_t single_char_offset = 4;
259
260 if (buf.format.back() == 'w') {
261 auto max_length = buf.format.substr(0, buf.format.length() - 1);
262 int64_t max_length_long = 0;
263 try {
264 max_length_long = std::stol(max_length);
265 } catch (const std::exception &e) {
266 MS_LOG(EXCEPTION) << "Invalid argument:" << e.what() << " when parse " << max_length;
267 }
268 auto string_max_length = LongToSize(max_length_long);
269 if (string_max_length == 0) {
270 MS_LOG(ERROR) << "Failed to get Tensor Desc. Please check string length";
271 return nullptr;
272 }
273 size_t elements_num = (data_buff_size / single_char_offset) / string_max_length;
274 std::vector<std::string> string_vector;
275 char *string_element = new char[string_max_length];
276 size_t string_length = 0;
277 for (size_t i = 0; i < elements_num; i++) {
278 (void)std::fill_n(string_element, string_max_length, '\0');
279 for (size_t j = 0; j < string_max_length; j++) {
280 char char_element = data_ptr[i * string_max_length * single_char_offset + single_char_offset * j];
281 if (static_cast<int>(char_element) == 0) {
282 break;
283 } else {
284 string_element[j] = char_element;
285 string_length += 1;
286 }
287 }
288 std::string string_to_add(string_element, string_length);
289 (void)string_vector.emplace_back(string_to_add);
290 }
291 delete[] string_element;
292 string_element = nullptr;
293 tensor_ptr = make_shared<GeTensor>(*desc);
294 (void)tensor_ptr->SetData(string_vector);
295 } else {
296 int64_t length_long = 0;
297 try {
298 length_long = std::stol(buf.format.substr(0, buf.format.length() - 1));
299 } catch (const std::exception &e) {
300 MS_LOG(EXCEPTION) << "Invalid argument:" << e.what() << " when parse "
301 << buf.format.substr(0, buf.format.length() - 1);
302 }
303 auto string_length = LongToSize(length_long);
304 if (string_length == 0) {
305 MS_LOG(ERROR) << "Failed to get Tensor Desc. Please check string length";
306 return nullptr;
307 }
308 char *string_element = new char[string_length];
309 for (size_t i = 0; i < string_length; i++) {
310 string_element[i] = data_ptr[i];
311 }
312 std::string string_to_add(string_element, string_length);
313 tensor_ptr = make_shared<GeTensor>(*desc);
314 (void)tensor_ptr->SetData(string_to_add);
315 delete[] string_element;
316 string_element = nullptr;
317 }
318 return tensor_ptr;
319 }
320 #endif
321
ConvertTensor(const MeTensorPtr & tensor,const std::string & format,bool copy)322 GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr &tensor, const std::string &format, bool copy) {
323 // get tensor data type size
324 MS_EXCEPTION_IF_NULL(tensor);
325 auto me_data_type = tensor->data_type();
326 #ifndef ENABLE_LITE_ACL
327 if (me_data_type == mindspore::kObjectTypeString) {
328 return ConvertStringTensor(tensor, format);
329 }
330 #endif
331 size_t type_size = GetDataTypeSize(me_data_type);
332 if (type_size == kErrorSize) {
333 MS_LOG(ERROR) << "The Me Tensor data type size is wrong, type size is: " << type_size;
334 return nullptr;
335 }
336
337 // get tensor buff size
338 size_t data_buff_size = tensor->Size();
339 if (data_buff_size == 0) {
340 MS_LOG(INFO) << "The Me Tensor data buff size is 0.";
341 }
342 // create ge tensor
343 auto desc = GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format);
344 if (desc == nullptr) {
345 MS_LOG(ERROR) << "Failed to get Tensor Desc";
346 return nullptr;
347 }
348 GeTensorPtr tensor_ptr = make_shared<GeTensor>(*desc);
349 if (tensor_ptr == nullptr) {
350 MS_LOG(ERROR) << "Failed to convert Me Tensor to Ge Tensor!";
351 return nullptr;
352 }
353 if (copy) {
354 auto ret = tensor_ptr->SetData(static_cast<uint8_t *>(tensor->data_c()), data_buff_size);
355 if (ret != ge::GRAPH_SUCCESS) {
356 MS_LOG(ERROR) << "Failed to call ge::Tensor SetData(const uint8_t*, size), data size " << data_buff_size;
357 return nullptr;
358 }
359 } else {
360 MsTensorRel rel(tensor);
361 auto ret = tensor_ptr->SetData(static_cast<uint8_t *>(tensor->data_c()), data_buff_size,
362 [rel](uint8_t *) -> void { rel.Rel(); });
363 if (ret != ge::GRAPH_SUCCESS) {
364 MS_LOG(ERROR) << "Failed to call ge::Tensor SetData(uint8_t*, size, DeleteFunc), data size " << data_buff_size;
365 return nullptr;
366 }
367 }
368 MS_LOG(DEBUG) << "Convert Me Tensor to Ge Tensor success!";
369 return tensor_ptr;
370 }
371
ConvertScalar(const ValuePtr & val)372 GeTensorPtr TransformUtil::ConvertScalar(const ValuePtr &val) {
373 auto ge_tensor = ConvertAnyUtil(val, AnyTraits<ValueAny>());
374 return make_shared<GeTensor>(ge_tensor);
375 }
376
ConvertGeTensors(const std::vector<GeTensorPtr> & ge_tensors,const std::vector<ShapeVector> & request_dims)377 std::vector<MeTensorPtr> TransformUtil::ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors,
378 const std::vector<ShapeVector> &request_dims) {
379 std::vector<MeTensorPtr> outputs;
380
381 for (size_t index = 0; index < ge_tensors.size(); index++) {
382 MeTensorPtr me_tensor_ptr = nullptr;
383 if (index < request_dims.size()) {
384 me_tensor_ptr = ConvertGeTensor(ge_tensors[index], request_dims[index]);
385 } else {
386 ShapeVector empty_shape;
387 me_tensor_ptr = ConvertGeTensor(ge_tensors[index], empty_shape);
388 }
389
390 if (me_tensor_ptr != nullptr) {
391 (void)outputs.emplace_back(me_tensor_ptr);
392 } else {
393 MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!";
394 return outputs;
395 }
396 }
397 return outputs;
398 }
399
ConvertGeTensors(const std::vector<GeTensorPtr> & ge_tensors)400 std::vector<MeTensorPtr> TransformUtil::ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors) {
401 std::vector<MeTensorPtr> outputs;
402
403 for (size_t index = 0; index < ge_tensors.size(); index++) {
404 MeTensorPtr me_tensor_ptr = ConvertGeTensor(ge_tensors[index]);
405 if (me_tensor_ptr != nullptr) {
406 (void)outputs.emplace_back(me_tensor_ptr);
407 } else {
408 MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!";
409 return outputs;
410 }
411 }
412 return outputs;
413 }
414
ConvertGeDataType(const GeDataType & type)415 MeDataType TransformUtil::ConvertGeDataType(const GeDataType &type) {
416 switch (type) {
417 case GeDataType::DT_FLOAT16:
418 return MeDataType::kNumberTypeFloat16;
419 case GeDataType::DT_BF16:
420 return MeDataType::kNumberTypeBFloat16;
421 case GeDataType::DT_FLOAT:
422 return MeDataType::kNumberTypeFloat32;
423 case GeDataType::DT_DOUBLE:
424 return MeDataType::kNumberTypeFloat64;
425 case GeDataType::DT_INT64:
426 return MeDataType::kNumberTypeInt64;
427 case GeDataType::DT_INT32:
428 return MeDataType::kNumberTypeInt32;
429 case GeDataType::DT_INT16:
430 return MeDataType::kNumberTypeInt16;
431 case GeDataType::DT_INT8:
432 return MeDataType::kNumberTypeInt8;
433 case GeDataType::DT_BOOL:
434 return MeDataType::kNumberTypeBool;
435 case GeDataType::DT_UINT8:
436 return MeDataType::kNumberTypeUInt8;
437 case GeDataType::DT_UINT16:
438 return MeDataType::kNumberTypeUInt16;
439 case GeDataType::DT_UINT32:
440 return MeDataType::kNumberTypeUInt32;
441 case GeDataType::DT_UINT64:
442 return MeDataType::kNumberTypeUInt64;
443 case GeDataType::DT_UNDEFINED:
444 case GeDataType::DT_DUAL_SUB_UINT8:
445 case GeDataType::DT_DUAL_SUB_INT8:
446 case GeDataType::DT_DUAL:
447 return MeDataType::kTypeUnknown;
448 default:
449 return MeDataType::kTypeUnknown;
450 }
451 }
452
453 namespace {
IsGeShapeCompatible(const GeShape & ge_shape,const ShapeVector & request_dims)454 bool IsGeShapeCompatible(const GeShape &ge_shape, const ShapeVector &request_dims) {
455 MS_LOG(INFO) << "GeTensor's shape is " << TransformUtil::PrintVector(ge_shape.GetDims());
456 MS_LOG(INFO) << "Me request shape is " << TransformUtil::PrintVector(request_dims);
457
458 const int GE_DIMS = 4;
459 std::vector<int64_t> ge_dims = ge_shape.GetDims();
460 if (request_dims.size() > ge_dims.size()) {
461 MS_LOG(ERROR) << "Request shape's dims count greater than ge shape's";
462 return false;
463 }
464
465 // convert NHWC to NCHW
466 if ((request_dims.size() == 1) && (ge_dims.size() == GE_DIMS) && (request_dims[kIdx0] == ge_dims[kIdx1]) &&
467 (ge_dims[kIdx0] == 1) && (ge_dims[kIdx2] == 1) && (ge_dims[kIdx3] == 1)) {
468 MS_LOG(INFO) << "Ge tensor shape and request shape is compatible";
469 return true;
470 }
471
472 std::string::size_type i = 0;
473 for (; i < request_dims.size(); i++) {
474 if (ge_dims[i] != request_dims[i]) {
475 MS_LOG(ERROR) << "Request shape's dims value not equal to ge shape's";
476 return false;
477 }
478 }
479
480 for (; i < ge_dims.size(); i++) {
481 if (ge_dims[i] != 1) {
482 MS_LOG(ERROR) << "GeShape's extend dims is not equal to 1";
483 return false;
484 }
485 }
486 MS_LOG(INFO) << "Ge tensor shape and request shape is compatible";
487 return true;
488 }
489 } // namespace
490
ConvertMeShape(const ShapeVector & me_dims)491 GeShape TransformUtil::ConvertMeShape(const ShapeVector &me_dims) {
492 std::vector<int64_t> ge_dims;
493 (void)std::copy(me_dims.begin(), me_dims.end(), std::back_inserter(ge_dims));
494 return GeShape(ge_dims);
495 }
496
ConvertGeShape(const GeShape & ge_shape)497 ShapeVector TransformUtil::ConvertGeShape(const GeShape &ge_shape) {
498 ShapeVector me_dims;
499 std::vector<int64_t> ge_dims = ge_shape.GetDims();
500 (void)std::copy(ge_dims.begin(), ge_dims.end(), std::back_inserter(me_dims));
501 return me_dims;
502 }
503
// Chooses the ME shape for a GE tensor given the shape requested by the caller.
// - A GE shape without dimensions yields an empty shape.
// - If IsGeShapeCompatible accepts the pair, the requested shape is returned.
// - Otherwise an error is logged and the GE shape itself is returned.
ShapeVector TransformUtil::ConvertGeShape(const GeShape &ge_shape, const ShapeVector &request_dims) {
  if (ge_shape.GetDimNum() == 0) {
    MS_LOG(DEBUG) << "GeTensor's shape is scalar";
    return {};
  }

  if (IsGeShapeCompatible(ge_shape, request_dims)) {
    return request_dims;
  }
  MS_LOG(ERROR) << "GeShape and Me request shape are incompatible, return GeShape";
  return ConvertGeShape(ge_shape);
}
519
GenerateMeTensor(const GeTensorPtr & ge_tensor,const ShapeVector & me_dims,const TypeId & me_type,bool ref_mem)520 MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr &ge_tensor, const ShapeVector &me_dims,
521 const TypeId &me_type, bool ref_mem) {
522 MS_EXCEPTION_IF_NULL(ge_tensor);
523 MS_EXCEPTION_IF_NULL(ge_tensor->GetData());
524 if (ge_tensor->GetSize() == 0) {
525 MS_LOG(ERROR) << "GE tensor data size is zero!";
526 return nullptr;
527 }
528
529 if (ref_mem) {
530 void *data = reinterpret_cast<void *>(const_cast<uint8_t *>(ge_tensor->GetData()));
531 ssize_t data_size = static_cast<ssize_t>(SizeOf(me_dims));
532 ssize_t itemsize = MeTensor(me_type, ShapeVector()).data().itemsize();
533 ssize_t ndim = static_cast<ssize_t>(me_dims.size());
534 auto ref_data = std::make_shared<TensorRefData>(data, data_size, itemsize, ndim);
535 return make_shared<MeTensor>(me_type, me_dims, ref_data);
536 } else {
537 MeTensor me_tensor(me_type, me_dims);
538
539 // Get the writable data pointer of the tensor and cast it to its data type.
540 auto me_data_ptr = me_tensor.data_c();
541 size_t me_data_size = static_cast<size_t>(me_tensor.data().nbytes());
542 MS_EXCEPTION_IF_NULL(me_data_ptr);
543 size_t length = ge_tensor->GetSize();
544 if (me_data_size < length) {
545 MS_LOG(ERROR) << "ME tensor data size[" << me_data_size << " bytes] is less than GE tensor [" << length
546 << " bytes]";
547 return nullptr;
548 }
549
550 if (length < SECUREC_MEM_MAX_LEN) {
551 int ret_code = memcpy_s(me_data_ptr, length, ge_tensor->GetData(), length);
552 if (ret_code != EOK) {
553 MS_LOG(ERROR) << "Memcpy_s from ge_tensor to me_tensor failed.";
554 return nullptr;
555 }
556 } else {
557 (void)memcpy(me_data_ptr, ge_tensor->GetData(), length);
558 }
559
560 return make_shared<MeTensor>(me_tensor);
561 }
562 }
563
ConvertGeTensor(const GeTensorPtr & ge_tensor)564 MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr &ge_tensor) {
565 MS_EXCEPTION_IF_NULL(ge_tensor);
566 GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape();
567 vector<int64_t> me_dims = ConvertGeShape(ge_shape);
568
569 TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType());
570 if (type_id == MeDataType::kTypeUnknown) {
571 MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: "
572 << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
573 return nullptr;
574 }
575 return GenerateMeTensor(ge_tensor, me_dims, type_id);
576 }
577
ConvertGeTensor(const GeTensorPtr & ge_tensor,const TypeId & me_type)578 MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr &ge_tensor, const TypeId &me_type) {
579 MS_EXCEPTION_IF_NULL(ge_tensor);
580 GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape();
581 vector<int64_t> me_dims = ConvertGeShape(ge_shape);
582
583 if (me_type == MeDataType::kTypeUnknown) {
584 MS_LOG(ERROR) << "Unsupported data type: " << static_cast<int>(me_type);
585 return nullptr;
586 }
587 return GenerateMeTensor(ge_tensor, me_dims, me_type);
588 }
589
590 // if request_dims is empty, use ge tensor's shape,otherwise convert to request shape
ConvertGeTensor(const GeTensorPtr ge_tensor,const ShapeVector & request_dims,bool ref_mem)591 MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr ge_tensor, const ShapeVector &request_dims, bool ref_mem) {
592 MS_EXCEPTION_IF_NULL(ge_tensor);
593 GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape();
594 vector<int64_t> me_dims = ConvertGeShape(ge_shape, request_dims);
595 MS_LOG(INFO) << "GE tensor type is " << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
596 // Create a tensor with wanted data type and shape
597 TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType());
598 if (type_id == MeDataType::kTypeUnknown) {
599 MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: "
600 << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
601 return nullptr;
602 }
603 return GenerateMeTensor(ge_tensor, me_dims, type_id, ref_mem);
604 }
605
PrintGeTensor(const GeTensorPtr ge_tensor)606 std::string TransformUtil::PrintGeTensor(const GeTensorPtr ge_tensor) {
607 std::string ret;
608 if (ge_tensor == nullptr) {
609 MS_LOG(ERROR) << "Input ge tensor is nullptr";
610 return ret;
611 }
612
613 MS_LOG(INFO) << "Ge Tensor data type is : " << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
614 switch (static_cast<int>(ge_tensor->GetTensorDesc().GetDataType())) {
615 case GeDataType::DT_UINT32:
616 ret = PrintVector(MakeVector<uint32_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
617 break;
618 case GeDataType::DT_FLOAT:
619 ret = PrintVector(MakeVector<float_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
620 break;
621 case GeDataType::DT_INT32:
622 ret = PrintVector(MakeVector<int32_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
623 break;
624 case GeDataType::DT_DOUBLE:
625 ret = PrintVector(MakeVector<double_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
626 break;
627 case GeDataType::DT_INT64:
628 ret = PrintVector(MakeVector<int64_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
629 break;
630 case GeDataType::DT_UINT64:
631 ret = PrintVector(MakeVector<uint64_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
632 break;
633 case GeDataType::DT_INT16:
634 ret = PrintVector(MakeVector<int16_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
635 break;
636 case GeDataType::DT_UINT16:
637 ret = PrintVector(MakeVector<uint16_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
638 break;
639 case GeDataType::DT_DUAL_SUB_INT8:
640 case GeDataType::DT_INT8:
641 ret = PrintVector(MakeVector<int8_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
642 break;
643 case GeDataType::DT_UINT8:
644 case GeDataType::DT_DUAL_SUB_UINT8:
645 ret = PrintVector(MakeVector<uint8_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
646 break;
647 case GeDataType::DT_FLOAT16:
648 case GeDataType::DT_BOOL:
649 case GeDataType::DT_UNDEFINED:
650 case GeDataType::DT_DUAL:
651 default:
652 MS_LOG(ERROR) << "Unsupported to print type:" << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType())
653 << " ge tensor";
654 break;
655 }
656 return ret;
657 }
658
NormOpName(const std::string & anf_name)659 std::string TransformUtil::NormOpName(const std::string &anf_name) {
660 std::string str = anf_name.substr(anf_name.rfind("/") + 1);
661 std::string ret;
662 for (const auto &c : str) {
663 if (std::isalnum(c) || c == '_' || c == '-') {
664 ret += c;
665 }
666 }
667 return ret;
668 }
669 } // namespace transform
670 } // namespace mindspore
671