/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/tensor.h"
#include <vector>
#include <string>
#include <utility>
#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <sstream>
#ifdef SUPPORT_NNRT
#include "src/litert/delegate/nnrt/nnrt_allocator.h"
#endif
#include "schema/ops_types_generated.h"
#include "securec/include/securec.h"
#include "include/errorcode.h"

namespace mindspore {
namespace lite {
namespace {
static const size_t max_malloc_size_ = GetMaxMallocSize();
}
#if ENABLE_HIGH_PERFORMANCE
#define CHECK_INT64_MUL_OVERFLOW(x, y)
#else
#define CHECK_INT64_MUL_OVERFLOW(x, y)       \
  do {                                       \
    if (INT64_MUL_OVERFLOW(x, y)) {          \
      MS_LOG(ERROR) << "INT64 MUL OVERFLOW"; \
      return INT64_MAX;                      \
    }                                        \
  } while (0)

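// Detects signed 64-bit multiplication overflow without performing the multiply:
// for a non-zero x, x * y overflows iff |y| exceeds INT64_MAX / x (with sign handling).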
#define INT64_MUL_OVERFLOW(x, y)                                                                   \
  (((x) == 0) ? false                                                                              \
              : ((x) > 0 ? (((y) >= 0) ? (INT64_MAX / (x)) < (y) : (INT64_MAX / (x)) < (-1 * (y))) \
                         : (((y) >= 0) ? (INT64_MAX / (x)) > (-1 * (y)) : (INT64_MAX / (x)) > (y))))
#endif

Tensor::Tensor(const TypeId data_type, std::vector<int> shape, const mindspore::Format &format, Category category) {
  tensor_c_ = {false, data_type, static_cast<int>(format), category, nullptr, shape.size()};
  if (shape.size() > MAX_SHAPE_SIZE) {
    tensor_c_.shape_size_ = 0;
    MS_LOG(WARNING) << "The shape-size has exceeded the limit 8, now is " << shape.size();
    return;
  }
  for (size_t i = 0; i < shape.size(); ++i) {
    tensor_c_.shape_[i] = shape[i];
  }
}

int Tensor::CopyTensorData(const Tensor &src_tensor, Tensor *dst_tensor) {
  if (dst_tensor == nullptr) {
    MS_LOG(ERROR) << "dst_tensor is nullptr";
    return RET_PARAM_INVALID;
  }
  if (src_tensor.tensor_c_.data_ == nullptr) {
    MS_LOG(INFO) << "data of src tensor is nullptr";
    return RET_OK;
  }
  size_t data_size = dst_tensor->Size();
  if (data_size != src_tensor.Size()) {
    MS_LOG(ERROR) << "Size of dst tensor is not compatible with src tensor";
    return RET_ERROR;
  }
  if (dst_tensor->MallocData() != RET_OK) {
    MS_LOG(ERROR) << "Malloc memory failed";
    return RET_ERROR;
  }
  dst_tensor->ResetRefCount();
  (void)memcpy(dst_tensor->tensor_c_.data_, src_tensor.tensor_c_.data_, data_size);
  return RET_OK;
}

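// Creates a copy of src_tensor's metadata and quantization parameters; the tensor data itself is
// duplicated only when copy_data is true.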
Tensor *Tensor::CopyTensor(const Tensor &src_tensor, bool copy_data, AllocatorPtr allocator) {
  auto *result = new (std::nothrow) Tensor;
  if (result == nullptr) {
    MS_LOG(ERROR) << "New tensor failed";
    return nullptr;
  }
  (void)memcpy(&result->tensor_c_, &src_tensor.tensor_c_, sizeof(TensorC));
  result->tensor_c_.data_ = nullptr;
  result->compress_type_ = src_tensor.compress_type_;
  result->compressed_size_ = src_tensor.compressed_size_;
  result->set_allocator(allocator);
  result->set_tensor_name(src_tensor.tensor_name() + "_duplicate");
  if (copy_data) {
    auto ret = CopyTensorData(src_tensor, result);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "CopyTensorData error";
      delete result;
      return nullptr;
    }
    result->own_data_ = src_tensor.own_data_;
  }

  for (const LiteQuantParam &quant : src_tensor.quant_params()) {
    result->AddQuantParam(quant);
  }

  return result;
}

Tensor::~Tensor() {
#ifdef SUPPORT_NNRT
  void *allocated_data = this->tensor_c_.data_;
#endif
  FreeData();
  this->tensor_c_.data_ = nullptr;
#ifdef SUPPORT_NNRT
  if (this->own_data_ && IS_NNRT_ALLOCATOR(allocator_)) {
    NNRTAllocator::GetInstance()->FreeAllocatedTensor(allocated_data, this);
  }
#endif
}

bool Tensor::operator==(const Tensor &tensor) {
  return tensor_c_.data_ == tensor.tensor_c_.data_ && tensor_c_.shape_size_ == tensor.tensor_c_.shape_size_ &&
         tensor_c_.data_type_ == tensor.tensor_c_.data_type_ &&
         std::equal(tensor_c_.shape_, tensor_c_.shape_ + tensor_c_.shape_size_, tensor.tensor_c_.shape_);
}

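// The accessors below map the logical batch/channel/height/width dimensions onto the stored shape
// according to the tensor's memory format (tensor_c_.format_).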
int32_t Tensor::Batch() const {
  // Only 2D or 4D tensors have valid batch.
  if (this->tensor_c_.shape_size_ != C4NUM && this->tensor_c_.shape_size_ != C2NUM) {
    MS_LOG(ERROR) << "Unsupported tensor shape: " << this->tensor_c_.shape_size_;
    return RET_ERROR;
  }
  switch (this->tensor_c_.format_) {
    case mindspore::NHWC:
    case mindspore::NHWC4:
    case mindspore::NCHW:
    case mindspore::NC4HW4:
    case mindspore::NC8HW8:
    case mindspore::KCHW:
    case mindspore::KHWC:
    case mindspore::NC:
    case mindspore::NC4:
      return this->tensor_c_.shape_[0];
    case mindspore::HWCK:
    case mindspore::CHWK:
      if (this->tensor_c_.shape_size_ != C4NUM) {
        return RET_ERROR;
      }
      return this->tensor_c_.shape_[C3NUM];
    case mindspore::HWKC:
      if (this->tensor_c_.shape_size_ != C4NUM) {
        return RET_ERROR;
      }
      return this->tensor_c_.shape_[C2NUM];
    case mindspore::CKHW:
      return this->tensor_c_.shape_[1];
    default:
      MS_LOG(ERROR) << "Unsupported format: " << EnumNameFormat(static_cast<schema::Format>(this->tensor_c_.format_));
      return RET_ERROR;
  }
}

int32_t Tensor::Channel() const {
  // Only 2D or 4D tensors have valid channel.
  if (this->tensor_c_.shape_size_ != C4NUM && this->tensor_c_.shape_size_ != C2NUM) {
    MS_LOG(ERROR) << "Unsupported tensor shape: " << this->tensor_c_.shape_size_;
    return RET_ERROR;
  }
  switch (this->tensor_c_.format_) {
    case mindspore::NCHW:
    case mindspore::KCHW:
    case mindspore::NC:
    case mindspore::NC4:
    case mindspore::NC4HW4:
    case mindspore::NC8HW8:
      return this->tensor_c_.shape_[1];
    case mindspore::HWCK:
      if (this->tensor_c_.shape_size_ != C4NUM) {
        return RET_ERROR;
      }
      return this->tensor_c_.shape_[C2NUM];
    case mindspore::HWKC:
    case mindspore::NHWC:
    case mindspore::NHWC4:
    case mindspore::KHWC:
      if (this->tensor_c_.shape_size_ != C4NUM) {
        return RET_ERROR;
      }
      return this->tensor_c_.shape_[C3NUM];
    case mindspore::CKHW:
    case mindspore::CHWK:
      return this->tensor_c_.shape_[0];
    default:
      return RET_ERROR;
  }
}

int32_t Tensor::Height() const {
  // Only 2D or 4D tensors have valid height.
  if (this->tensor_c_.shape_size_ != C4NUM && this->tensor_c_.shape_size_ != C2NUM) {
    MS_LOG(ERROR) << "Unsupported tensor shape: " << this->tensor_c_.shape_size_;
    return RET_ERROR;
  }
  switch (this->tensor_c_.format_) {
    case mindspore::NCHW:
    case mindspore::KCHW:
    case mindspore::CKHW:
    case mindspore::NC4HW4:
    case mindspore::NC8HW8:
      if (this->tensor_c_.shape_size_ != C4NUM) {
        return RET_ERROR;
      }
      return this->tensor_c_.shape_[C2NUM];
    case mindspore::NHWC:
    case mindspore::NHWC4:
    case mindspore::KHWC:
    case mindspore::CHWK:
      return this->tensor_c_.shape_[1];
    case mindspore::HWCK:
    case mindspore::HWKC:
    case mindspore::HW:
    case mindspore::HW4:
      return this->tensor_c_.shape_[0];
    default:
      MS_LOG(ERROR) << "Unsupported format: " << EnumNameFormat(static_cast<schema::Format>(this->tensor_c_.format_));
      return RET_ERROR;
  }
}

int32_t Tensor::Width() const {
  // Only 2D or 4D tensors have valid width.
  if (this->tensor_c_.shape_size_ != C4NUM && this->tensor_c_.shape_size_ != C2NUM) {
    MS_LOG(ERROR) << "Unsupported tensor shape: " << this->tensor_c_.shape_size_;
    return RET_ERROR;
  }
  switch (this->tensor_c_.format_) {
    case mindspore::NCHW:
    case mindspore::KCHW:
    case mindspore::CKHW:
    case mindspore::NC4HW4:
    case mindspore::NC8HW8:
      if (this->tensor_c_.shape_size_ != C4NUM) {
        return RET_ERROR;
      }
      return this->tensor_c_.shape_[C3NUM];
    case mindspore::KHWC:
    case mindspore::NHWC:
    case mindspore::NHWC4:
    case mindspore::CHWK:
      if (this->tensor_c_.shape_size_ != C4NUM) {
        return RET_ERROR;
      }
      return this->tensor_c_.shape_[C2NUM];
    case mindspore::HWCK:
    case mindspore::HWKC:
    case mindspore::HW:
    case mindspore::HW4:
      return this->tensor_c_.shape_[1];
    default:
      return RET_ERROR;
  }
}

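// Returns the tensor's data size in bytes. Compressed tensors report compressed_size_; NC4HW4 and
// NHWC4 formats account for channel padding via ElementsC4Num().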
size_t Tensor::Size() const {
  if (compress_type_ != kNoCompression) {
    return compressed_size_;
  } else {
    size_t element_size = DataTypeSize(static_cast<TypeId>(tensor_c_.data_type_));
    if (element_size == 0) {
      MS_LOG(INFO) << "Unexpected data type: " << tensor_c_.data_type_;
      return 0;
    }
    auto element_num = (tensor_c_.format_ == mindspore::NC4HW4 || tensor_c_.format_ == mindspore::NHWC4)
                         ? ElementsC4Num()
                         : ElementsNum();
    if (element_num <= 0) {
      std::vector<int> shape(tensor_c_.shape_, tensor_c_.shape_ + tensor_c_.shape_size_);
      MS_LOG(DEBUG) << "Element number of tensor should be larger than 0: " << element_num << ", shape: " << shape;
      return 0;
    }
    return element_size * static_cast<size_t>(element_num);
  }
}

int64_t Tensor::ElementsNum() const {
  if (this->tensor_c_.category_ == CONST_SCALAR) {
    return 1;
  }
  if (tensor_c_.format_ == mindspore::NC4HW4) {
    return ElementsC4Num();
  }
  if (tensor_c_.format_ == mindspore::NC8HW8) {
    return ElementsC8Num();
  }
  int64_t num = 1;
  for (size_t i = 0; i < tensor_c_.shape_size_; ++i) {
    if (tensor_c_.shape_[i] < 0) {
      return 0;
    }
    CHECK_INT64_MUL_OVERFLOW(num, tensor_c_.shape_[i]);
    num *= tensor_c_.shape_[i];
  }
  return num;
}

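// Counts elements with the channel dimension padded up to a multiple of 4 (C4-aligned layouts).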
int64_t Tensor::ElementsC4Num() const {
  if (this->tensor_c_.category_ == CONST_SCALAR) {
    return 1;
  }
  int64_t result = 1;
  constexpr int kC4Align = 4;
  if (this->tensor_c_.shape_size_ == C4NUM) {
    CHECK_INT64_MUL_OVERFLOW(result, Batch());
    result *= Batch();
    CHECK_INT64_MUL_OVERFLOW(result, Height());
    result *= Height();
    CHECK_INT64_MUL_OVERFLOW(result, Width());
    result *= Width();
    CHECK_INT64_MUL_OVERFLOW(result, (Channel() + 3LL) / kC4Align * kC4Align);
    result *= (Channel() + 3LL) / kC4Align * kC4Align;
  } else if (this->tensor_c_.shape_size_ == 3) {  // 3 : [H W C]
    CHECK_INT64_MUL_OVERFLOW(result, this->tensor_c_.shape_[0]);
    result *= this->tensor_c_.shape_[0];
    CHECK_INT64_MUL_OVERFLOW(result, this->tensor_c_.shape_[1]);
    result *= this->tensor_c_.shape_[1];
    CHECK_INT64_MUL_OVERFLOW(result, (this->tensor_c_.shape_[2] + 3LL) / kC4Align * kC4Align);  // C : 2
    result *= (this->tensor_c_.shape_[2] + 3LL) / kC4Align * kC4Align;                          // C : 2
  } else if (this->tensor_c_.shape_size_ == 2) {  // 2 : [W C]
    CHECK_INT64_MUL_OVERFLOW(result, this->tensor_c_.shape_[0]);
    result *= this->tensor_c_.shape_[0];
    CHECK_INT64_MUL_OVERFLOW(result, (this->tensor_c_.shape_[1] + 3LL) / kC4Align * kC4Align);
    result *= (this->tensor_c_.shape_[1] + 3LL) / kC4Align * kC4Align;
  } else if (this->tensor_c_.shape_size_ == 1) {  // 1 : C
    CHECK_INT64_MUL_OVERFLOW(result, (this->tensor_c_.shape_[0] + 3LL) / kC4Align * kC4Align);
    result *= (this->tensor_c_.shape_[0] + 3LL) / kC4Align * kC4Align;
  } else {
    MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size();
  }
  return result;
}

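// Counts elements with the channel dimension padded up to a multiple of 8 (C8-aligned layouts).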
int64_t Tensor::ElementsC8Num() const {
  if (this->tensor_c_.category_ == CONST_SCALAR) {
    return 1;
  }
  int64_t result = 1;
  constexpr int kC8Align = 8;
  if (this->tensor_c_.shape_size_ == C4NUM) {
    CHECK_INT64_MUL_OVERFLOW(result, Batch());
    result *= Batch();
    CHECK_INT64_MUL_OVERFLOW(result, Height());
    result *= Height();
    CHECK_INT64_MUL_OVERFLOW(result, Width());
    result *= Width();
    CHECK_INT64_MUL_OVERFLOW(result, (Channel() + 7LL) / kC8Align * kC8Align);
    result *= (Channel() + 7LL) / kC8Align * kC8Align;
  } else if (this->tensor_c_.shape_size_ == C2NUM) {
    CHECK_INT64_MUL_OVERFLOW(result, this->tensor_c_.shape_[0]);
    result *= this->tensor_c_.shape_[0];
    CHECK_INT64_MUL_OVERFLOW(result, (this->tensor_c_.shape_[1] + 7LL) / kC8Align * kC8Align);
    result *= (this->tensor_c_.shape_[1] + 7LL) / kC8Align * kC8Align;
  }
  return result;
}

int Tensor::DimensionSize(const size_t index) const {
  int dim_size = -1;
  if (index < tensor_c_.shape_size_) {
    dim_size = tensor_c_.shape_[index];
  } else {
    MS_LOG(ERROR) << "Dimension index is out of range: " << index;
  }
  return dim_size;
}

std::string Tensor::ToString() const {
  std::ostringstream oss;
  oss << "Tensor name: " << this->tensor_name();
  oss << " schema::Format: " << EnumNameFormat(static_cast<schema::Format>(this->tensor_c_.format_));
  oss << " DataType: " << this->tensor_c_.data_type_;
  oss << " Category: " << this->tensor_c_.category_;
  oss << " Shape:";
  for (auto &dim : this->shape()) {
    oss << " " << dim;
  }
  oss << std::endl << "Data:";
  auto data = tensor_c_.data_;
  switch (this->tensor_c_.data_type_) {
    case kNumberTypeFloat32: {
      oss << DataToString<float>(data, this->ElementsNum());
    } break;
    case kNumberTypeFloat16: {
      oss << DataToString<int16_t>(data, this->ElementsNum());
    } break;
    case kNumberTypeInt32: {
      oss << DataToString<int32_t>(data, this->ElementsNum());
    } break;
    case kNumberTypeInt16: {
      oss << DataToString<int16_t>(data, this->ElementsNum());
    } break;
    case kNumberTypeInt8: {
      oss << DataToString<int8_t>(data, this->ElementsNum());
    } break;
    default:
      oss << "Unsupported data type to print";
      break;
  }
  return oss.str();
}

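// Allocates Size() bytes for the tensor data, preferring the tensor's allocator (with a dedicated
// NNRT path when SUPPORT_NNRT is enabled) and falling back to plain malloc when no allocator is set.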
int Tensor::MallocData(const AllocatorPtr allocator) {
  if (this->tensor_c_.data_ != nullptr) {
    return RET_OK;
  }
  if (allocator != nullptr) {
    allocator_ = allocator;
  }
  size_t element_size = DataTypeSize(static_cast<TypeId>(this->tensor_c_.data_type_));
  if (element_size == 0) {
    MS_LOG(ERROR) << "Unexpected data type: " << tensor_c_.data_type_;
    return RET_ERROR;
  }
  auto data_size = this->Size();
  if (data_size <= 0) {
    MS_LOG(INFO) << "Data size=" << data_size << " bytes";
    // Expected to return here, but kept non-fatal to support tensors with a (0, xx) shape (e.g. where_fp32).
  }
  if (data_size > max_malloc_size_) {
    MS_LOG(ERROR) << "Malloc size is too big while copying data, " << data_size << " bytes";
    return RET_ERROR;
  }
  if (allocator_ == nullptr) {
    this->tensor_c_.data_ = malloc(data_size);
  } else {
#ifdef SUPPORT_NNRT
    if (IS_NNRT_ALLOCATOR(allocator_)) {
      this->tensor_c_.data_ = dynamic_cast<NNRTAllocator *>(allocator_.get())
                                ->MallocByDesc(data_size, this->shape(), this->data_type(), this->format(),
                                               this->tensor_name());
      dynamic_cast<NNRTAllocator *>(allocator_.get())->AddAllocatedLiteTensor(this->tensor_c_.data_, this);
    } else {
#endif
      this->tensor_c_.data_ = allocator_->Malloc(data_size);
#ifdef SUPPORT_NNRT
    }
#endif
    allocator_->SetRefCount(this->tensor_c_.data_, 1);
  }
  if (this->tensor_c_.data_ == nullptr) {
    MS_LOG(ERROR) << "Malloc tensor data failed, size=" << data_size;
    return RET_ERROR;
  }
  this->own_data_ = true;
  return RET_OK;
}

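// Releases the tensor data according to its allocator: runtime allocators are left untouched,
// allocator-managed buffers are freed once their reference count drops to zero, and plain
// malloc'ed buffers owned by the tensor are freed directly.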
void Tensor::FreeData() {
  if (IS_RUNTIME_ALLOCATOR(allocator_)) {
    return;
  }
  if (this->tensor_c_.data_ != nullptr && this->own_data_) {
    if (this->allocator_ != nullptr) {
      if (allocator_->DecRefCount(this->tensor_c_.data_, 1) <= 0) {
        allocator_->Free(this->tensor_c_.data_);  // Because various allocators are in use, do not reset data_ here.
      }
      if (!IS_STATIC_ALLOCATOR(allocator_) || allocator_->RefCount(this->tensor_c_.data_) != 0) {
        this->tensor_c_.data_ = nullptr;
      }
    } else {
      free(this->tensor_c_.data_);
      this->tensor_c_.data_ = nullptr;
    }
  } else if (this->tensor_c_.category_ == Category::VAR) {
    if (!IS_STATIC_ALLOCATOR(allocator_) || allocator_->RefCount(this->tensor_c_.data_) != 0) {
      if (this->init_ref_count_ == 1) {
        this->tensor_c_.data_ = nullptr;
      }
    }
  }
}

void *Tensor::ReallocData() {
  if (this->tensor_c_.data_ != nullptr) {
    FreeData();
  }
  return this->MutableData();
}

void *Tensor::MutableData() {
  if (this->tensor_c_.data_ == nullptr) {
    auto ret = this->MallocData();
    if (ret != 0) {
      MS_LOG(WARNING) << "Malloc data failed";
    }
  }
  Prepare();
  return this->tensor_c_.data_;
}

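// Decrements the reference count and releases non-const tensor data once no consumers remain;
// graph inputs are never freed here.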
void Tensor::DecRefCount() {
  if (this->IsGraphInput()) {
    return;
  }
  int tensor_ref_count = --ref_count_;
  if (tensor_ref_count <= 0) {
    tensor_c_.shape_changed_ = false;
    if (this->IsConst()) {
      return;
    }
    FreeData();
  }
}

void Tensor::AddQuantParam(const LiteQuantParam &quant_param) { this->quant_params_.push_back(quant_param); }

void Tensor::ClearQuantParam() {
  // quant_params() returns a copy, so clear the member directly and release its capacity.
  std::vector<LiteQuantParam>().swap(quant_params_);
}

std::vector<LiteQuantParam> Tensor::quant_params() const { return this->quant_params_; }

void Tensor::set_quant_params(const std::vector<LiteQuantParam> quant_params) { this->quant_params_ = quant_params; }

std::vector<float> Tensor::quant_clusters() const { return this->quant_clusters_; }

void Tensor::set_quant_clusters(const std::vector<float> &clusters) { this->quant_clusters_ = clusters; }

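// Creates a tensor that wraps (does not copy) the caller-provided buffer; the caller keeps ownership
// of data and must keep it alive for the tensor's lifetime.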
Tensor *Tensor::CreateTensor(const std::string &name, TypeId type, const std::vector<int> &shape, const void *data,
                             size_t data_len) {
  auto tensor = std::make_unique<lite::Tensor>();
  if (tensor == nullptr) {
    MS_LOG(ERROR) << "Failed to allocate tensor.";
    return nullptr;
  }
  if (std::any_of(shape.begin(), shape.end(), [](const int &element) { return element < 0 && element != -1; })) {
    MS_LOG(ERROR) << "Dims of tensor: " << shape << " are unsupported.";
    return nullptr;
  }
  int shape_size = 0;
  if (std::any_of(shape.begin(), shape.end(), [](const int &element) { return element == -1; })) {
    shape_size = -1;
  } else {
    shape_size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
  }
  if (shape_size == -1 && data != nullptr) {
    MS_LOG(ERROR) << "A tensor with a dynamic shape cannot be a const tensor.";
    return nullptr;
  }
  auto data_type_size = lite::DataTypeSize(type);
  if (data_type_size == 0) {
    MS_LOG(ERROR) << "Unsupported data type to create tensor: " << type;
    return nullptr;
  }
  if (data == nullptr && data_len != 0) {
    MS_LOG(ERROR) << "Shape, data type and data length do not match.";
    return nullptr;
  }

  if (data != nullptr && data_len != shape_size * data_type_size) {
    MS_LOG(ERROR) << "Shape, data type and data length do not match.";
    return nullptr;
  }
  if (shape.size() > MAX_SHAPE_SIZE) {
    MS_LOG(ERROR) << "The shape-size has exceeded the limit 8, now is " << shape.size();
    return nullptr;
  }
  tensor->set_data(const_cast<void *>(data), false);
  tensor->set_shape(shape);
  tensor->set_tensor_name(name);
  tensor->set_data_type(type);
  tensor->set_category(data != nullptr ? (shape.empty() ? CONST_SCALAR : CONST_TENSOR) : VAR);
  return tensor.release();
}

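// Like CreateTensor, but the new tensor owns a freshly malloc'ed copy of the caller-provided buffer;
// if data_len does not match the shape, the tensor falls back to a one-dimensional shape of data_len.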
Tensor *Tensor::CreateTensorByDeepCopy(const std::string &name, TypeId type, const std::vector<int> &shape,
                                       const void *data, size_t data_len) {
  auto tensor = std::make_unique<lite::Tensor>();
  if (tensor == nullptr) {
    MS_LOG(ERROR) << "Failed to allocate tensor.";
    return nullptr;
  }

  auto data_type_size = lite::DataTypeSize(type);
  if (data_type_size == 0) {
    MS_LOG(ERROR) << "Unsupported data type to create tensor: " << type;
    return nullptr;
  }

  if (data_len > MAX_MALLOC_SIZE) {
    MS_LOG(ERROR) << "data length is invalid.";
    return nullptr;
  } else if (data_len == 0 && data != nullptr) {
    MS_LOG(ERROR) << "data length and data do not match.";
    return nullptr;
  } else if (data_len == 0 && data == nullptr) {
    tensor->set_data(const_cast<void *>(data));
  } else {
    void *new_data = malloc(data_len);
    if (new_data == nullptr) {
      MS_LOG(ERROR) << "Failed to malloc data.";
      return nullptr;
    }
    if (data != nullptr) {
      (void)memcpy(new_data, data, data_len);
    }
    tensor->set_data(new_data);
  }

  size_t shape_size = 1;
  if (shape.empty()) {
    shape_size = 0;
  } else {
    for (size_t i = 0; i < shape.size(); ++i) {
      if (shape[i] < 0) {
        return nullptr;
      }
      shape_size *= static_cast<size_t>(shape[i]);
    }
  }
  if (data_len != shape_size * data_type_size) {
    std::vector<int> truncate_shape = {static_cast<int>(data_len)};
    tensor->set_shape(truncate_shape);
  } else {
    tensor->set_shape(shape);
  }
  tensor->set_tensor_name(name);
  tensor->set_data_type(type);
  tensor->set_category(data != nullptr ? (shape.empty() ? CONST_SCALAR : CONST_TENSOR) : VAR);
  return tensor.release();
}

}  // namespace lite
}  // namespace mindspore