• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/tensor.h"
18 #include <vector>
19 #include <string>
20 #include <utility>
21 #include <algorithm>
22 #include "securec/include/securec.h"
23 #include "include/errorcode.h"
24 
25 namespace mindspore {
26 namespace lite {
27 namespace {
28 static const size_t max_malloc_size_ = GetMaxMallocSize();
29 }  // namespace
Tensor(const TypeId data_type,std::vector<int> shape,const mindspore::Format & format,Category category)30 Tensor::Tensor(const TypeId data_type, std::vector<int> shape, const mindspore::Format &format, Category category)
31     : data_type_(data_type), shape_(std::move(shape)), format_(format), category_(category) {}
32 
CopyTensorData(const Tensor & src_tensor,Tensor * dst_tensor)33 int Tensor::CopyTensorData(const Tensor &src_tensor, Tensor *dst_tensor) {
34   if (dst_tensor == nullptr) {
35     MS_LOG(ERROR) << "dst_tensor is nullptr";
36     return RET_PARAM_INVALID;
37   }
38   if (src_tensor.data_ == nullptr) {
39     MS_LOG(INFO) << "data of src tensor is nullptr";
40     return RET_OK;
41   }
42   size_t data_size = dst_tensor->Size();
43   if (data_size != src_tensor.Size()) {
44     MS_LOG(ERROR) << "Size of dst tensor is not compatible with src tensor";
45     return RET_ERROR;
46   }
47   if (dst_tensor->MallocData() != RET_OK) {
48     MS_LOG(ERROR) << "Malloc memory failed";
49     return RET_ERROR;
50   }
51   dst_tensor->ResetRefCount();
52   memcpy(dst_tensor->data_, src_tensor.data_, data_size);
53   return RET_OK;
54 }
55 
// Creates a heap-allocated duplicate of `src_tensor` (data type, shape, category, format and quant
// params). The copy is named "<src name>_duplicate" and uses `allocator` for its buffer.
//   copy_data: when true, the source bytes are copied into a newly allocated buffer.
// Returns the new tensor (caller takes ownership and must delete it), or nullptr on failure.
Tensor *Tensor::CopyTensor(const Tensor &src_tensor, bool copy_data, AllocatorPtr allocator) {
  auto *result = new (std::nothrow) Tensor;
  if (result == nullptr) {
    MS_LOG(ERROR) << "New tensor failed";
    return nullptr;
  }
  result->data_type_ = src_tensor.data_type_;
  result->shape_ = src_tensor.shape_;
  result->category_ = src_tensor.category_;
  result->format_ = src_tensor.format_;
  result->set_allocator(allocator);
  result->set_tensor_name(src_tensor.tensor_name() + "_duplicate");
  if (copy_data) {
    auto ret = CopyTensorData(src_tensor, result);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "CopyTensorData error";
      delete result;
      return nullptr;
    }
    // When bytes were copied, MallocData() gave `result` its own fresh buffer and marked it as owned.
    // Inheriting a false own_data_ from the source would stop FreeData() from ever releasing that
    // buffer (leak), so the source's flag is only propagated when there was no data to copy.
    if (src_tensor.data_ == nullptr) {
      result->own_data_ = src_tensor.own_data_;
    }
  }

  // `result` is freshly constructed, so replacing its quant params equals appending each one.
  // NOTE(review): quant_clusters_ is not copied here (same as before) — confirm that is intended.
  result->set_quant_params(src_tensor.quant_params());

  return result;
}
84 
~Tensor()85 Tensor::~Tensor() {
86   FreeData();
87   this->data_ = nullptr;
88 }
89 
operator ==(const Tensor & tensor)90 bool Tensor::operator==(const Tensor &tensor) {
91   return data_ == tensor.data_ && shape_ == tensor.shape_ && data_type_ == tensor.data_type_;
92 }
93 
Batch() const94 int32_t Tensor::Batch() const {
95   if (this->shape_.size() != 4 && this->shape_.size() != 2) {
96     MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size();
97     return RET_ERROR;
98   }
99   switch (this->format_) {
100     case mindspore::NHWC:
101     case mindspore::NHWC4:
102     case mindspore::NCHW:
103     case mindspore::NC4HW4:
104     case mindspore::KCHW:
105     case mindspore::KHWC:
106     case mindspore::NC:
107     case mindspore::NC4:
108       return this->shape_[0];
109     case mindspore::HWCK:
110     case mindspore::CHWK:
111       return this->shape_[3];
112     case mindspore::HWKC:
113       return this->shape_[2];
114     case mindspore::CKHW:
115       return this->shape_[1];
116     default:
117       MS_LOG(ERROR) << "Unsupported format: " << EnumNameFormat(static_cast<schema::Format>(this->format_));
118       return RET_ERROR;
119   }
120 }
121 
Channel() const122 int32_t Tensor::Channel() const {
123   if (this->shape_.size() != 4 && this->shape_.size() != 2) {
124     MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size();
125     return RET_ERROR;
126   }
127   switch (this->format_) {
128     case mindspore::NCHW:
129     case mindspore::KCHW:
130     case mindspore::NC:
131     case mindspore::NC4:
132       return this->shape_[1];
133     case mindspore::HWCK:
134       return this->shape_[2];
135     case mindspore::HWKC:
136     case mindspore::NHWC:
137     case mindspore::NHWC4:
138     case mindspore::NC4HW4:
139     case mindspore::KHWC:
140       return this->shape_[3];
141     case mindspore::CKHW:
142     case mindspore::CHWK:
143       return this->shape_[0];
144     default:
145       return RET_ERROR;
146   }
147 }
148 
Height() const149 int32_t Tensor::Height() const {
150   if (this->shape_.size() != 4 && this->shape_.size() != 2) {
151     MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size();
152     return RET_ERROR;
153   }
154   switch (this->format_) {
155     case mindspore::NCHW:
156     case mindspore::KCHW:
157     case mindspore::CKHW:
158       return this->shape_[2];
159     case mindspore::NHWC:
160     case mindspore::NHWC4:
161     case mindspore::NC4HW4:
162     case mindspore::KHWC:
163     case mindspore::CHWK:
164       return this->shape_[1];
165     case mindspore::HWCK:
166     case mindspore::HWKC:
167     case mindspore::HW:
168     case mindspore::HW4:
169       return this->shape_[0];
170     default:
171       MS_LOG(ERROR) << "Unsupported format: " << EnumNameFormat(static_cast<schema::Format>(this->format_));
172       return RET_ERROR;
173   }
174 }
175 
Width() const176 int32_t Tensor::Width() const {
177   if (this->shape_.size() != 4 && this->shape_.size() != 2) {
178     MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size();
179     return RET_ERROR;
180   }
181   switch (this->format_) {
182     case mindspore::NCHW:
183     case mindspore::KCHW:
184     case mindspore::CKHW:
185       return this->shape_[3];
186     case mindspore::KHWC:
187     case mindspore::NHWC:
188     case mindspore::NHWC4:
189     case mindspore::NC4HW4:
190     case mindspore::CHWK:
191       return this->shape_[2];
192     case mindspore::HWCK:
193     case mindspore::HWKC:
194     case mindspore::HW:
195     case mindspore::HW4:
196       return this->shape_[1];
197     default:
198       return RET_ERROR;
199   }
200 }
201 
Size() const202 size_t Tensor::Size() const {
203   size_t element_size = DataTypeSize(this->data_type_);
204   auto element_num = (format_ == mindspore::NC4HW4 || format_ == mindspore::NHWC4) ? ElementsC4Num() : ElementsNum();
205   if (element_num < 0) {
206     MS_LOG(INFO) << "Element number of tensor should large than 0 : " << element_num;
207     return 0;
208   }
209   return element_size * element_num;
210 }
211 
ElementsNum() const212 int Tensor::ElementsNum() const {
213   if (this->category_ == CONST_SCALAR) {
214     return 1;
215   }
216   auto num = std::accumulate(shape_.begin(), shape_.end(), 1LL, std::multiplies<int64_t>());
217   if (num > (int64_t)INT32_MAX) {
218     MS_LOG(ERROR) << "Element number of tensor should be smaller than int32_max: " << num << " return INT32_MAX";
219     return INT32_MAX;
220   }
221   return (int32_t)num;
222 }
223 
ElementsC4Num() const224 int32_t Tensor::ElementsC4Num() const {
225   if (this->category_ == CONST_SCALAR) {
226     return 1;
227   }
228   int32_t result = 1;
229   if (this->shape_.size() == 4) {
230     result = Batch() * Height() * Width() * ((Channel() + 3) / 4 * 4);
231   } else if (this->shape_.size() == 2) {
232     result = this->shape_[0] * ((this->shape_[1] + 3) / 4 * 4);
233   }
234   return result;
235 }
236 
DimensionSize(const size_t index) const237 int Tensor::DimensionSize(const size_t index) const {
238   int dim_size = -1;
239   if (index < shape_.size()) {
240     dim_size = shape_[index];
241   } else {
242     MS_LOG(ERROR) << "Dimension index is wrong: " << index;
243   }
244   return dim_size;
245 }
246 
ToString() const247 std::string Tensor::ToString() const {
248   std::ostringstream oss;
249   oss << "schema::Format: " << EnumNameFormat(static_cast<schema::Format>(this->format_));
250   oss << " DataType: " << this->data_type_;
251   oss << " Category: " << this->category_;
252   oss << " Shape:";
253   for (auto &dim : this->shape()) {
254     oss << " " << dim;
255   }
256   oss << std::endl << "Data:";
257   switch (this->data_type_) {
258     case kNumberTypeFloat32: {
259       oss << DataToString<float>(data_, this->ElementsNum());
260     } break;
261     case kNumberTypeFloat16: {
262       oss << DataToString<int16_t>(data_, this->ElementsNum());
263     } break;
264     case kNumberTypeInt32: {
265       oss << DataToString<int32_t>(data_, this->ElementsNum());
266     } break;
267     case kNumberTypeInt16: {
268       oss << DataToString<int16_t>(data_, this->ElementsNum());
269     } break;
270     case kNumberTypeInt8: {
271       oss << DataToString<int8_t>(data_, this->ElementsNum());
272     } break;
273     default:
274       oss << "Unsupported data type to print";
275       break;
276   }
277   return oss.str();
278 }
279 
MallocData(const AllocatorPtr allocator)280 int Tensor::MallocData(const AllocatorPtr allocator) {
281   if (this->data_ != nullptr) {
282     return RET_OK;
283   }
284   if (allocator != nullptr) {
285     allocator_ = allocator;
286   }
287   size_t element_size = DataTypeSize(this->data_type_);
288   if (element_size == 0) {
289     MS_LOG(ERROR) << "Unexpected data type: " << data_type_;
290     return RET_ERROR;
291   }
292   auto data_size = this->Size();
293   if (data_size <= 0) {
294     MS_LOG(INFO) << "Data size=" << data_size << " bytes";
295     // expect return, currently not return for case (0,xx) shape tensor (where_fp32)
296   }
297   if (data_size > max_malloc_size_) {
298     MS_LOG(ERROR) << "Malloc size is too big while coping data, " << data_size << " bytes";
299     return RET_ERROR;
300   }
301   if (allocator_ == nullptr) {
302     this->data_ = malloc(data_size);
303   } else {
304     this->data_ = allocator_->Malloc(data_size);
305     allocator_->SetRefCount(this->data_, 1);
306   }
307   if (this->data_ == nullptr) {
308     MS_LOG(ERROR) << "Malloc tensor data failed, size=" << data_size;
309     return RET_ERROR;
310   }
311   this->own_data_ = true;
312   return RET_OK;
313 }
314 
FreeData()315 void Tensor::FreeData() {
316   if (IS_RUNTIME_ALLOCATOR(allocator_)) {
317     return;
318   }
319   if (this->data_ != nullptr && this->own_data_) {
320     if (this->allocator_ != nullptr) {
321       if (allocator_->DecRefCount(this->data_, 1) <= 0) {
322         allocator_->Free(this->data_);  // Due to existing various allocator, here do not set data to nullptr.
323       }
324       if (!IS_STATIC_ALLOCATOR(allocator_) || allocator_->RefCount(this->data_) != 0) {
325         this->data_ = nullptr;
326       }
327     } else {
328       free(this->data_);
329       this->data_ = nullptr;
330     }
331   }
332 }
333 
ReallocData()334 void *Tensor::ReallocData() {
335   if (this->data_ != nullptr) {
336     FreeData();
337   }
338   return this->MutableData();
339 }
340 
MutableData()341 void *Tensor::MutableData() {
342   if (this->data_ == nullptr) {
343     auto ret = this->MallocData();
344     if (ret != 0) {
345       MS_LOG(WARNING) << "Malloc data failed";
346     }
347   }
348   Prepare();
349   return this->data_;
350 }
351 
IncRefCount()352 void Tensor::IncRefCount() {
353   ref_count_++;
354 }
355 
DecRefCount()356 void Tensor::DecRefCount() {
357  if (this->IsConst() || this->IsGraphInput()) {
358     return;
359   }
360   int tensor_ref_count = --ref_count_;
361   if (tensor_ref_count <= 0) {
362     FreeData();
363   }
364 }
365 
AddQuantParam(const LiteQuantParam & quant_param)366 void Tensor::AddQuantParam(const LiteQuantParam &quant_param) { this->quant_params_.push_back(quant_param); }
367 
quant_params() const368 std::vector<LiteQuantParam> Tensor::quant_params() const { return this->quant_params_; }
369 
set_quant_params(const std::vector<LiteQuantParam> quant_params)370 void Tensor::set_quant_params(const std::vector<LiteQuantParam> quant_params) { this->quant_params_ = quant_params; }
371 
quant_clusters() const372 std::vector<float> Tensor::quant_clusters() const { return this->quant_clusters_; }
373 
set_quant_clusters(const std::vector<float> & clusters)374 void Tensor::set_quant_clusters(const std::vector<float> &clusters) { this->quant_clusters_ = clusters; }
375 
TensorVectorCast(const std::vector<Tensor * > & src)376 std::vector<tensor::MSTensor *> TensorVectorCast(const std::vector<Tensor *> &src) {
377   std::vector<tensor::MSTensor *> target(src.size());
378   std::transform(src.begin(), src.end(), target.begin(), [](Tensor *t) { return static_cast<tensor::MSTensor *>(t); });
379   return target;
380 }
381 
382 }  // namespace lite
383 
384 }  // namespace mindspore
385