• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "src/tensorlist.h"
18 #include <utility>
19 #include <algorithm>
20 #include "include/ms_tensor.h"
21 #include "src/common/log_adapter.h"
22 #include "src/tensor.h"
23 #include "nnacl/op_base.h"
24 
25 namespace mindspore::lite {
TensorList(std::vector<int> shape,std::vector<int> element_shape,Category category)26 TensorList::TensorList(std::vector<int> shape, std::vector<int> element_shape, Category category)
27     : Tensor(kObjectTypeTensorType, std::move(shape), mindspore::NHWC, category),
28       element_shape_(std::move(element_shape)) {}
29 
~TensorList()30 TensorList::~TensorList() {
31   if (!this->tensors_.empty()) {
32     this->TensorList::FreeData();
33     this->FreeTensorListData();
34   }
35 }
36 
CopyTensorList(const TensorList & src,bool copy_data)37 int TensorList::CopyTensorList(const TensorList &src, bool copy_data) {
38   this->data_type_ = src.data_type_;
39   this->tensors_data_type_ = src.tensors_data_type_;
40   this->shape_ = src.shape_;
41   this->element_shape_ = src.element_shape_;
42   this->max_elements_num_ = src.max_elements_num_;
43   if (copy_data) {
44     auto ret = CopyTensorData(src);
45     if (ret != RET_OK) {
46       MS_LOG(ERROR) << "CopyTensorData error";
47       return RET_ERROR;
48     }
49   } else {
50     for (auto tensor : this->tensors()) {
51       delete tensor;
52     }
53     this->tensors_.clear();
54     // each tensor in tensors_ will share the same memory space.
55     this->tensors_ = src.tensors_;
56   }
57   return RET_OK;
58 }
59 
CopyTensorData(const TensorList & src)60 int TensorList::CopyTensorData(const TensorList &src) {
61   if (src.tensors_.empty()) {
62     return RET_OK;
63   }
64   for (auto tensor : this->tensors()) {
65     delete tensor;
66   }
67   this->tensors_.clear();
68   for (int i = 0; i < this->ElementsNum(); ++i) {
69     if (src.tensors_[i] == nullptr) {
70       MS_LOG(ERROR) << "src tensors_[" << i << "] is nullptr!";
71       return RET_ERROR;
72     }
73     auto dst_tensor = Tensor::CopyTensor(*src.tensors_[i]);
74     if (dst_tensor == nullptr) {
75       MS_LOG(ERROR) << "CopyTensorData: new tensor[" << i << "] is failed!";
76       return RET_ERROR;
77     }
78     this->tensors_.push_back(dst_tensor);
79   }
80   return RET_OK;
81 }
82 
MallocTensorListData(TypeId dtype,const std::vector<std::vector<int>> & tensor_shape)83 int TensorList::MallocTensorListData(TypeId dtype, const std::vector<std::vector<int> > &tensor_shape) {
84   // This function will create a new tensors_
85   // Your must to set shape(param2: tensor_shape) and data_type_(tensors_data_type_ = param1: dtype) of each tensor in
86   // tensors_. After that, you need to call function:MallocData to malloc data buf of each tensor in tensors_.
87   if (!this->tensors_.empty()) {
88     // If tensors_ is not empty then clear this tensors_ and rebuild a new tensors_.
89     auto ret = FreeTensorListData();
90     if (ret != RET_OK) {
91       return RET_ERROR;
92     }
93   }
94   if (this->shape().size() != 1) {
95     MS_LOG(ERROR) << "tensorlist shape:" << this->shape().size() << " must be one-dimensional";
96     return RET_ERROR;
97   }
98   if (static_cast<size_t>(this->ElementsNum()) != tensor_shape.size()) {
99     MS_LOG(ERROR) << "tensorlist ElementsNum():" << this->ElementsNum()
100                   << " must be equal to param2:tensor_shape.size():" << tensor_shape.size();
101     return RET_ERROR;
102   }
103   this->tensors_data_type_ = dtype;
104   for (int i = 0; i < this->ElementsNum(); ++i) {
105     auto tensor_ptr = new (std::nothrow) Tensor(dtype, tensor_shape[i]);
106     if (tensor_ptr == nullptr) {
107       MS_LOG(ERROR) << "new tensors_[" << i << "] is failed!";
108       return RET_ERROR;
109     }
110     if (!this->allocator()) {
111       tensor_ptr->set_allocator(this->allocator());
112     }
113     tensor_ptr->set_init_ref_count(this->init_ref_count());
114     tensor_ptr->set_ref_count(this->ref_count());
115     this->tensors_.push_back(tensor_ptr);
116   }
117   return RET_OK;
118 }
119 
MallocData(const AllocatorPtr allocator)120 int TensorList::MallocData(const AllocatorPtr allocator) {
121   if (allocator != nullptr) {
122     allocator_ = allocator;
123   }
124   // malloc data buf of each tensor in tensors_
125   for (int i = 0; i < this->ElementsNum(); ++i) {
126     if (tensors_.empty()) {
127       return RET_OK;
128     }
129     auto tensor_ptr = this->tensors_[i];
130     if (tensor_ptr == nullptr) {
131       MS_LOG(ERROR) << "tensors_[" << i << "] is nullptr!";
132       return RET_ERROR;
133     }
134     // if data_type() is kTypeUnknown then data buf will not to be malloc
135     if (tensor_ptr->data_type() != kTypeUnknown) {
136       auto ret = tensor_ptr->MallocData(this->allocator_);
137       if (ret != RET_OK) {
138         MS_LOG(ERROR) << "tensorlist malloc tensors_[:" << i << "] is failed!";
139         return RET_ERROR;
140       }
141     }
142   }
143   return RET_OK;
144 }
145 
FreeData()146 void TensorList::FreeData() {
147   if (this->IsConst() || this->IsGraphInput()) {
148     return;
149   }
150   // free data buf of each tensor in tensors_
151   for (auto tensor : tensors_) {
152     if (tensor == nullptr) {
153       continue;
154     }
155     tensor->FreeData();
156   }
157 }
158 
FreeTensorListData()159 int TensorList::FreeTensorListData() {
160   // del each tensor in tensors_ and clear tensors_
161   if (this->tensors_.empty()) {
162     return RET_OK;
163   }
164   for (auto &tensor : this->tensors_) {
165     if (tensor != nullptr) {
166       delete tensor;
167       tensor = nullptr;
168     }
169   }
170   tensors_.clear();
171   return RET_OK;
172 }
173 
SetTensor(int index,const Tensor * src_tensor)174 int TensorList::SetTensor(int index, const Tensor *src_tensor) {
175   MS_CHECK_TRUE_MSG(src_tensor != nullptr, RET_ERROR, "src tensor cannot null");
176   // your can use this fun to modify tensor[index] value
177   if (src_tensor->data_type() != this->tensors_data_type_) {
178     MS_LOG(ERROR) << "src_tensor->data_type():" << src_tensor->data_type()
179                   << " must be equal to tensors_data_type_:" << this->tensors_data_type_;
180     return RET_ERROR;
181   }
182   if (index < 0 || index > (this->ElementsNum() - 1)) {
183     MS_LOG(ERROR) << "index:" << index << " must in [0, " << this->ElementsNum() - 1 << "]!";
184     return RET_ERROR;
185   }
186   auto dst_tensor = this->tensors_[index];
187   // free original tensor data
188   delete dst_tensor;
189   this->tensors_[index] = Tensor::CopyTensor(*src_tensor);
190   if (this->tensors_[index] == nullptr) {
191     MS_LOG(ERROR) << "SetTensor: new tensor is failed!";
192     return RET_ERROR;
193   }
194   return RET_OK;
195 }
196 
CheckTensorListParam()197 int TensorList::CheckTensorListParam() {
198   for (int i = 0; i < this->ElementsNum(); ++i) {
199     // each tensor in tensorlist must be not nullptr
200     if (this->tensors_[i] == nullptr) {
201       MS_LOG(ERROR) << "CheckTensorListParam: tensors_[" << i << "] is nullptr";
202       return RET_ERROR;
203     }
204     if (this->tensors_[i]->data_type() != this->tensors_data_type_) {
205       MS_LOG(ERROR) << "CheckTensorListParam: tensors_[i] data_type:" << this->tensors_[i]->data_type()
206                     << " is not equal to tensors_data_type_:" << this->tensors_data_type_;
207       return RET_ERROR;
208     }
209   }
210   return RET_OK;
211 }
212 
GetTensor(int index)213 Tensor *TensorList::GetTensor(int index) {
214   // return tensor[index] ptr. With this function, you can modify tensors_[index] at will.
215   if (index < 0 || index >= static_cast<int>(this->tensors_.size())) {
216     MS_LOG(ERROR) << "index:" << index << " must in [0, " << this->ElementsNum() - 1 << "]!";
217     return nullptr;
218   }
219   return this->tensors_[index];
220 }
221 
IsCompatibleShape(const std::vector<int> & shape)222 bool TensorList::IsCompatibleShape(const std::vector<int> &shape) {
223   if (this->tensors_.empty() && this->element_shape_.empty()) {
224     return true;
225   }
226   if (shape.size() != this->element_shape_.size()) {
227     return false;
228   }
229   for (size_t i = 0; i < shape.size(); ++i) {
230     if (this->element_shape_[i] >= 0 && shape[i] >= 0 && this->element_shape_[i] != shape[i]) {
231       return false;
232     }
233   }
234   return true;
235 }
236 
IsCompatibleShape(const Tensor * src)237 bool TensorList::IsCompatibleShape(const Tensor *src) {
238   MS_CHECK_TRUE_MSG(src != nullptr, false, "src tensor cannot null");
239   // shape is store in Tensor.
240   if (static_cast<size_t>(src->ElementsNum()) != this->element_shape_.size()) {
241     return false;
242   }
243   if (src->data_type() != kNumberTypeInt && src->data_type() != kNumberTypeInt32) {
244     MS_LOG(ERROR) << "src tensor data_type:" << src->data_type() << " is not int";
245     return false;
246   }
247   auto src_ptr = reinterpret_cast<int *>(src->data());
248   for (size_t i = 0; i < this->element_shape_.size(); ++i) {
249     if (this->element_shape_[i] >= 0 && src_ptr[i] >= 0 && this->element_shape_[i] != src_ptr[i]) {
250       return false;
251     }
252   }
253   return true;
254 }
255 
Decode(const int * data)256 STATUS TensorList::Decode(const int *data) {
257   if (data == nullptr) {
258     MS_LOG(ERROR) << "data is nullptr";
259     return RET_ERROR;
260   }
261   tensors_data_type_ = TypeId(data[0]);
262   for (int j = 0; j < data[1]; ++j) {
263     element_shape_.push_back(data[2 + j]);
264   }
265   int tensors_num = data[2 + data[1]];
266   if (tensors_num < 0) {
267     MS_LOG(WARNING) << "not able to create tensors, need infer shape.";
268     return RET_OK;
269   }
270 
271   if (this->ElementsNum() != tensors_num) {
272     MS_LOG(WARNING) << "Input tensorlist data is invalid: shape size(" << this->ElementsNum() << ") != tensors_num("
273                     << tensors_num << ").";
274     MS_LOG(WARNING) << "tensor name: " << this->tensor_name_;
275   }
276   tensors_.reserve(tensors_num);
277   int tensor_index = 2 + data[1] + 1;
278   for (int i = 0; i < tensors_num; i++) {
279     int tensor_dims_size = data[tensor_index++];
280     std::vector<int> shape(tensor_dims_size);
281     for (int j = 0; j < tensor_dims_size; j++) {
282       shape[j] = data[tensor_index++];
283     }
284     auto tensor = new (std::nothrow) Tensor(tensors_data_type_, shape);
285     if (tensor == nullptr) {
286       MS_LOG(ERROR) << "new Tensor failed";
287       return RET_NULL_PTR;
288     }
289     tensors_.emplace_back(tensor);
290   }
291   return RET_OK;
292 }
293 
IsConst() const294 bool TensorList::IsConst() const { return this->category_ == CONST_TENSOR || this->category_ == CONST_SCALAR; }
295 }  // namespace mindspore::lite
296