1 /** 2 * Copyright 2020-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_TENSORLIST_H_ 18 #define MINDSPORE_LITE_SRC_TENSORLIST_H_ 19 20 #include <memory> 21 #include <vector> 22 #include "include/errorcode.h" 23 #include "nnacl/tensorlist_c.h" 24 #include "src/common/log_adapter.h" 25 #include "schema/model_generated.h" 26 #include "src/tensor.h" 27 28 namespace mindspore::lite { 29 #ifndef CONTROLFLOW_TENSORLIST_CLIP 30 /** 31 * Tensorlist is a container of vector, in which each element is a tensor object. 32 * Member objects: 33 * 1.tensors_: tensors_ is a vector, where each element is a pointer to tensor type. 34 * 2.shape_: represents the size of the tensors_ and shape_.size() must be equal to 1. 35 * 3.element_shape_: element_shape_ represents the shape of each tensor in tensors_. 36 * Some dimensions can be negative, which means that the corresponding dimensions of each tensor in tensors_ can be 37 * different. 38 * 4.data_type_: indicates that the tensorlist is a tensor of type kObjectTypeTensorType, so it can only be 39 * "kObjectTypeTensorType" 40 * 5.tensors_data_type_: data_type_ of each tensor in tensors_ 41 * Usage: 42 * std::vector<int> shape = (1, 2); // tensors_ only has two tensor 43 * std::vector<int> element_shape = {-1, 99}; 44 * // dim0 is arbitrary and dim1 is must to be 99 of each tensor.shape() in tensors_ 45 * TensorList *tl = new TensorList(shape, element_shape); 46 * std::vector<std::vector<int> > tensor_shape = std::vector<vector<int> > (2, 47 * (std::vector<int> {5, 99}, 48 * std::vector<int> {1, 99})); 49 * // tensor_shape[0] and tensor_shape[1] is not equal in dim0, but dim1 is must be equal to 99. 50 * t1->MallocTensorListData(kNumberTypeFloat, tensor_shape); 51 * t1->MallocData(); 52 * t1->... 53 * ... 54 * t1->FreeData(); 55 * t1->FreeTensorListData(); 56 * 57 * See the code for other constructors. 58 */ 59 class MS_API TensorList : public Tensor { 60 public: TensorList()61 TensorList() { tensor_list_c_ = {false, kObjectTypeTensorType, DEFAULT_FORMAT, 0, kTypeUnknown, -1, nullptr, 0, 0}; } 62 63 TensorList(std::vector<int> shape, std::vector<int> element_shape, Category category = VAR); 64 65 ~TensorList() override; 66 67 TensorList(const TensorList &other) = delete; 68 69 TensorList &operator=(const TensorList &tl) = delete; 70 set_element_shape(const std::vector<int> & shape)71 void set_element_shape(const std::vector<int> &shape) { 72 if (shape.size() > MAX_SHAPE_SIZE) { 73 FreeData(); 74 tensor_list_c_.element_shape_size_ = 0; 75 MS_LOG(WARNING) << "The shape-size has exceeded the limit 8, now is " << shape.size(); 76 return; 77 } 78 tensor_list_c_.element_shape_size_ = shape.size(); 79 for (size_t i = 0; i < shape.size(); ++i) { 80 tensor_list_c_.element_shape_[i] = shape[i]; 81 } 82 } 83 element_shape()84 std::vector<int> element_shape() const { 85 return std::vector<int>(tensor_list_c_.element_shape_, 86 tensor_list_c_.element_shape_ + tensor_list_c_.element_shape_size_); 87 } 88 set_max_elements_num(int ele_num)89 void set_max_elements_num(int ele_num) { tensor_list_c_.max_elements_num_ = ele_num; } 90 max_elements_num()91 int max_elements_num() const { return tensor_list_c_.max_elements_num_; } 92 93 static TensorList *CopyTensorList(const TensorList &src, bool copy_data = false, 94 const AllocatorPtr &allocator = nullptr); 95 96 int MallocTensorListData(TypeId dtype, const std::vector<std::vector<int> > &tensor_shape); 97 98 int MallocData(const AllocatorPtr allocator = nullptr) override; 99 100 int FreeTensorListData(); 101 102 void FreeData() override; 103 104 int SetTensor(int index, const Tensor *src_tensor); 105 106 Tensor *GetTensor(int index); 107 set_tensors_data_type(TypeId type)108 void set_tensors_data_type(TypeId type) { tensor_list_c_.tensors_data_type_ = type; } 109 tensors_data_type()110 TypeId tensors_data_type() const { return static_cast<TypeId>(tensor_list_c_.tensors_data_type_); } 111 tensors()112 std::vector<Tensor *> tensors() { return tensors_; } 113 set_tensors(const std::vector<Tensor * > & tensors)114 void set_tensors(const std::vector<Tensor *> &tensors) { this->tensors_ = tensors; } 115 116 int CheckTensorListParam(); 117 118 bool IsCompatibleShape(const std::vector<int> &shape); 119 120 bool IsCompatibleShape(const Tensor *src); 121 122 STATUS Decode(const int *data, size_t length); 123 124 bool IsConst() const override; 125 set_init_ref_count(int ref_count)126 void set_init_ref_count(int ref_count) override { 127 this->init_ref_count_ = ref_count; 128 for (auto tensor : tensors_) { 129 if (tensor != nullptr) { 130 tensor->set_init_ref_count(ref_count); 131 } 132 } 133 } 134 set_ref_count(int ref_count)135 void set_ref_count(int ref_count) override { 136 ref_count_ = ref_count; 137 for (auto tensor : tensors_) { 138 if (tensor != nullptr) { 139 tensor->set_ref_count(ref_count); 140 } 141 } 142 } 143 ResetRefCount()144 void ResetRefCount() override { 145 set_ref_count(this->init_ref_count_); 146 for (auto tensor : tensors_) { 147 if (tensor != nullptr) { 148 tensor->set_ref_count(this->init_ref_count_); 149 } 150 } 151 } 152 IncRefCount()153 void IncRefCount() override { 154 ++ref_count_; 155 for (auto tensor : tensors_) { 156 if (tensor != nullptr) { 157 tensor->IncRefCount(); 158 } 159 } 160 } 161 DecRefCount()162 void DecRefCount() override { 163 if (this->IsConst() || this->IsGraphInput()) { 164 return; 165 } 166 --ref_count_; 167 for (auto tensor : tensors_) { 168 if (tensor != nullptr) { 169 tensor->DecRefCount(); 170 } 171 } 172 } 173 set_allocator(AllocatorPtr allocator)174 void set_allocator(AllocatorPtr allocator) override { 175 allocator_ = allocator; 176 for (auto tensor : tensors_) { 177 if (tensor != nullptr) { 178 tensor->set_allocator(allocator); 179 } 180 } 181 } 182 set_own_data(bool own_data)183 void set_own_data(bool own_data) override { 184 this->own_data_ = own_data; 185 for (auto tensor : tensors_) { 186 if (tensor != nullptr) { 187 tensor->set_own_data(own_data); 188 } 189 } 190 } 191 ConvertToTensorListC()192 TensorListC *ConvertToTensorListC() { 193 tensor_list_c_.format_ = tensor_c_.format_; 194 tensor_list_c_.shape_value_ = tensor_c_.shape_size_ == 0 ? 0 : tensor_c_.shape_[0]; 195 tensor_list_c_.element_num_ = tensor_c_.shape_size_ == 0 ? 0 : tensors_.size(); 196 tensor_list_c_.tensors_ = nullptr; 197 return &tensor_list_c_; 198 } 199 200 protected: 201 // The following functions must be masked. data()202 void *data() const override { return nullptr; } MutableData()203 void *MutableData() override { return nullptr; } Size()204 size_t Size() const override { return 0; } 205 TensorListC tensor_list_c_; 206 std::vector<Tensor *> tensors_{}; 207 }; 208 209 #else 210 211 using TensorList = void; 212 213 #endif 214 } // namespace mindspore::lite 215 #endif // MINDSPORE_LITE_SRC_TENSORLIST_H_ 216