/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/tensorlist.h"
#include <utility>
#include <algorithm>
#include "include/ms_tensor.h"
#include "src/common/log_adapter.h"
#include "src/tensor.h"
#include "nnacl/op_base.h"

namespace mindspore::lite {
TensorList::TensorList(std::vector<int> shape, std::vector<int> element_shape, Category category)
    : Tensor(kObjectTypeTensorType, std::move(shape), mindspore::NHWC, category),
      element_shape_(std::move(element_shape)) {}

TensorList::~TensorList() {
  if (!this->tensors_.empty()) {
    this->TensorList::FreeData();
    this->FreeTensorListData();
  }
}

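// Usage sketch (illustrative; `dst` and `src` are hypothetical TensorList objects): the
// copy_data flag of CopyTensorList selects between a deep copy of every element tensor
// (via CopyTensorData) and a shallow copy that shares the Tensor pointers with `src`.
//
//   dst.CopyTensorList(src, true);   // element tensors are duplicated
//   dst.CopyTensorList(src, false);  // tensors_ is assigned from src; Tensor* are shared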
int TensorList::CopyTensorList(const TensorList &src, bool copy_data) {
  this->data_type_ = src.data_type_;
  this->tensors_data_type_ = src.tensors_data_type_;
  this->shape_ = src.shape_;
  this->element_shape_ = src.element_shape_;
  this->max_elements_num_ = src.max_elements_num_;
  if (copy_data) {
    auto ret = CopyTensorData(src);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "CopyTensorData error";
      return RET_ERROR;
    }
  } else {
    for (auto tensor : this->tensors()) {
      delete tensor;
    }
    this->tensors_.clear();
    // after this assignment both TensorLists hold the same Tensor pointers (shared memory).
    this->tensors_ = src.tensors_;
  }
  return RET_OK;
}

int TensorList::CopyTensorData(const TensorList &src) {
  if (src.tensors_.empty()) {
    return RET_OK;
  }
  for (auto tensor : this->tensors()) {
    delete tensor;
  }
  this->tensors_.clear();
  for (int i = 0; i < this->ElementsNum(); ++i) {
    if (src.tensors_[i] == nullptr) {
      MS_LOG(ERROR) << "src tensors_[" << i << "] is nullptr!";
      return RET_ERROR;
    }
    auto dst_tensor = Tensor::CopyTensor(*src.tensors_[i]);
    if (dst_tensor == nullptr) {
      MS_LOG(ERROR) << "CopyTensorData: copy tensor[" << i << "] failed!";
      return RET_ERROR;
    }
    this->tensors_.push_back(dst_tensor);
  }
  return RET_OK;
}

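// Usage sketch (illustrative; `tlist` is a hypothetical TensorList whose shape() is {2}):
// MallocTensorListData creates one element tensor per entry of tensor_shape with the given
// dtype, and MallocData then allocates the data buffer of each element tensor.
//
//   std::vector<std::vector<int>> shapes = {{2, 3}, {4}};
//   if (tlist.MallocTensorListData(kNumberTypeFloat32, shapes) != RET_OK) { /* handle error */ }
//   if (tlist.MallocData(nullptr) != RET_OK) { /* handle error */ }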
int TensorList::MallocTensorListData(TypeId dtype, const std::vector<std::vector<int>> &tensor_shape) {
  // This function creates a new tensors_.
  // You must set the shape (param2: tensor_shape) and the data type (tensors_data_type_ = param1: dtype) of each
  // tensor in tensors_. After that, call MallocData to allocate the data buffer of each tensor in tensors_.
  if (!this->tensors_.empty()) {
    // if tensors_ is not empty, free it and rebuild it from scratch.
    auto ret = FreeTensorListData();
    if (ret != RET_OK) {
      return RET_ERROR;
    }
  }
  if (this->shape().size() != 1) {
    MS_LOG(ERROR) << "tensorlist shape size:" << this->shape().size() << " must be one-dimensional";
    return RET_ERROR;
  }
  if (static_cast<size_t>(this->ElementsNum()) != tensor_shape.size()) {
    MS_LOG(ERROR) << "tensorlist ElementsNum():" << this->ElementsNum()
                  << " must be equal to param2:tensor_shape.size():" << tensor_shape.size();
    return RET_ERROR;
  }
  this->tensors_data_type_ = dtype;
  for (int i = 0; i < this->ElementsNum(); ++i) {
    auto tensor_ptr = new (std::nothrow) Tensor(dtype, tensor_shape[i]);
    if (tensor_ptr == nullptr) {
      MS_LOG(ERROR) << "new tensors_[" << i << "] failed!";
      return RET_ERROR;
    }
    if (this->allocator() != nullptr) {
      tensor_ptr->set_allocator(this->allocator());
    }
    tensor_ptr->set_init_ref_count(this->init_ref_count());
    tensor_ptr->set_ref_count(this->ref_count());
    this->tensors_.push_back(tensor_ptr);
  }
  return RET_OK;
}

int TensorList::MallocData(const AllocatorPtr allocator) {
  if (allocator != nullptr) {
    allocator_ = allocator;
  }
  if (tensors_.empty()) {
    return RET_OK;
  }
  // malloc the data buffer of each tensor in tensors_
  for (int i = 0; i < this->ElementsNum(); ++i) {
    auto tensor_ptr = this->tensors_[i];
    if (tensor_ptr == nullptr) {
      MS_LOG(ERROR) << "tensors_[" << i << "] is nullptr!";
      return RET_ERROR;
    }
    // if data_type() is kTypeUnknown, no data buffer is allocated for this tensor
    if (tensor_ptr->data_type() != kTypeUnknown) {
      auto ret = tensor_ptr->MallocData(this->allocator_);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "tensorlist malloc tensors_[" << i << "] failed!";
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}

void TensorList::FreeData() {
  if (this->IsConst() || this->IsGraphInput()) {
    return;
  }
  // free the data buffer of each tensor in tensors_
  for (auto tensor : tensors_) {
    if (tensor == nullptr) {
      continue;
    }
    tensor->FreeData();
  }
}

int TensorList::FreeTensorListData() {
  // delete each tensor in tensors_ and clear tensors_
  if (this->tensors_.empty()) {
    return RET_OK;
  }
  for (auto &tensor : this->tensors_) {
    if (tensor != nullptr) {
      delete tensor;
      tensor = nullptr;
    }
  }
  tensors_.clear();
  return RET_OK;
}

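// Usage sketch (illustrative; `tlist` and `new_value` are hypothetical): SetTensor deep-copies
// src_tensor into slot `index` after deleting the tensor previously stored there; the data type
// of src_tensor must equal tensors_data_type_.
//
//   Tensor new_value(kNumberTypeFloat32, {2, 3});
//   tlist.SetTensor(0, &new_value);  // tensors_[0] now holds an independent copy of new_value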
int TensorList::SetTensor(int index, const Tensor *src_tensor) {
  MS_CHECK_TRUE_MSG(src_tensor != nullptr, RET_ERROR, "src tensor cannot be null");
  // use this function to replace the tensor stored at tensors_[index]
  if (src_tensor->data_type() != this->tensors_data_type_) {
    MS_LOG(ERROR) << "src_tensor->data_type():" << src_tensor->data_type()
                  << " must be equal to tensors_data_type_:" << this->tensors_data_type_;
    return RET_ERROR;
  }
  if (index < 0 || index > (this->ElementsNum() - 1)) {
    MS_LOG(ERROR) << "index:" << index << " must be in [0, " << this->ElementsNum() - 1 << "]!";
    return RET_ERROR;
  }
  auto dst_tensor = this->tensors_[index];
  // free the original tensor stored at this slot
  delete dst_tensor;
  this->tensors_[index] = Tensor::CopyTensor(*src_tensor);
  if (this->tensors_[index] == nullptr) {
    MS_LOG(ERROR) << "SetTensor: CopyTensor failed!";
    return RET_ERROR;
  }
  return RET_OK;
}

int TensorList::CheckTensorListParam() {
  for (int i = 0; i < this->ElementsNum(); ++i) {
    // each tensor in the tensorlist must not be nullptr
    if (this->tensors_[i] == nullptr) {
      MS_LOG(ERROR) << "CheckTensorListParam: tensors_[" << i << "] is nullptr";
      return RET_ERROR;
    }
    if (this->tensors_[i]->data_type() != this->tensors_data_type_) {
      MS_LOG(ERROR) << "CheckTensorListParam: tensors_[" << i << "] data_type:" << this->tensors_[i]->data_type()
                    << " is not equal to tensors_data_type_:" << this->tensors_data_type_;
      return RET_ERROR;
    }
  }
  return RET_OK;
}

Tensor *TensorList::GetTensor(int index) {
  // return the tensors_[index] pointer; the caller may modify that tensor through it.
  if (index < 0 || index >= static_cast<int>(this->tensors_.size())) {
    MS_LOG(ERROR) << "index:" << index << " must be in [0, " << this->ElementsNum() - 1 << "]!";
    return nullptr;
  }
  return this->tensors_[index];
}

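// Compatibility rule (illustrative example): a negative value in element_shape_ or in the
// candidate shape acts as an unknown dimension, so only dimensions known on both sides must match.
//
//   element_shape_ = {-1, 3}:  IsCompatibleShape({5, 3}) -> true,  IsCompatibleShape({5, 4}) -> false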
bool TensorList::IsCompatibleShape(const std::vector<int> &shape) {
  if (this->tensors_.empty() && this->element_shape_.empty()) {
    return true;
  }
  if (shape.size() != this->element_shape_.size()) {
    return false;
  }
  for (size_t i = 0; i < shape.size(); ++i) {
    if (this->element_shape_[i] >= 0 && shape[i] >= 0 && this->element_shape_[i] != shape[i]) {
      return false;
    }
  }
  return true;
}

bool TensorList::IsCompatibleShape(const Tensor *src) {
  MS_CHECK_TRUE_MSG(src != nullptr, false, "src tensor cannot be null");
  // the candidate element shape is stored in src's data buffer as int values.
  if (static_cast<size_t>(src->ElementsNum()) != this->element_shape_.size()) {
    return false;
  }
  if (src->data_type() != kNumberTypeInt && src->data_type() != kNumberTypeInt32) {
    MS_LOG(ERROR) << "src tensor data_type:" << src->data_type() << " is not int";
    return false;
  }
  auto src_ptr = reinterpret_cast<int *>(src->data());
  for (size_t i = 0; i < this->element_shape_.size(); ++i) {
    if (this->element_shape_[i] >= 0 && src_ptr[i] >= 0 && this->element_shape_[i] != src_ptr[i]) {
      return false;
    }
  }
  return true;
}

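// Layout of the flat int buffer consumed by Decode, as read by the parsing code below (symbolic
// example; a concrete TypeId value would replace `dtype`):
//
//   data[0]                       element data type (TypeId)
//   data[1]                       rank of element_shape_
//   data[2 .. 2 + data[1] - 1]    element_shape_ values
//   data[2 + data[1]]             tensors_num (negative means shapes must still be inferred)
//   then, for each tensor:        its rank followed by its dimensions
//
//   e.g. {dtype, 2, -1, 3, 2, /*t0*/ 2, 4, 3, /*t1*/ 2, 5, 3}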
STATUS TensorList::Decode(const int *data) {
  if (data == nullptr) {
    MS_LOG(ERROR) << "data is nullptr";
    return RET_ERROR;
  }
  tensors_data_type_ = TypeId(data[0]);
  for (int j = 0; j < data[1]; ++j) {
    element_shape_.push_back(data[2 + j]);
  }
  int tensors_num = data[2 + data[1]];
  if (tensors_num < 0) {
    MS_LOG(WARNING) << "unable to create tensors, shape inference is needed.";
    return RET_OK;
  }

  if (this->ElementsNum() != tensors_num) {
    MS_LOG(WARNING) << "Input tensorlist data is invalid: shape size(" << this->ElementsNum() << ") != tensors_num("
                    << tensors_num << ").";
    MS_LOG(WARNING) << "tensor name: " << this->tensor_name_;
  }
  tensors_.reserve(tensors_num);
  int tensor_index = 2 + data[1] + 1;
  for (int i = 0; i < tensors_num; i++) {
    int tensor_dims_size = data[tensor_index++];
    std::vector<int> shape(tensor_dims_size);
    for (int j = 0; j < tensor_dims_size; j++) {
      shape[j] = data[tensor_index++];
    }
    auto tensor = new (std::nothrow) Tensor(tensors_data_type_, shape);
    if (tensor == nullptr) {
      MS_LOG(ERROR) << "new Tensor failed";
      return RET_NULL_PTR;
    }
    tensors_.emplace_back(tensor);
  }
  return RET_OK;
}

bool TensorList::IsConst() const { return this->category_ == CONST_TENSOR || this->category_ == CONST_SCALAR; }
}  // namespace mindspore::lite