1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_TENSORLIST_H_ 18 #define MINDSPORE_LITE_SRC_TENSORLIST_H_ 19 20 #include <memory> 21 #include <vector> 22 #include "include/ms_tensor.h" 23 #include "include/errorcode.h" 24 #include "src/common/log_adapter.h" 25 #include "schema/model_generated.h" 26 #include "src/tensor.h" 27 28 namespace mindspore::lite { 29 /** 30 * Tensorlist is a container of vector, in which each element is a tensor object. 31 * Member objects: 32 * 1.tensors_: tensors_ is a vector, where each element is a pointer to tensor type. 33 * 2.shape_: represents the size of the tensors_ and shape_.size() must be equal to 1. 34 * 3.element_shape_: element_shape_ represents the shape of each tensor in tensors_. 35 * Some dimensions can be negative, which means that the corresponding dimensions of each tensor in tensors_ can be 36 * different. 37 * 4.data_type_: indicates that the tensorlist is a tensor of type kObjectTypeTensorType, so it can only be 38 * "kObjectTypeTensorType" 39 * 5.tensors_data_type_: data_type_ of each tensor in tensors_ 40 * Usage: 41 * std::vector<int> shape = (1, 2); // tensors_ only has two tensor 42 * std::vector<int> element_shape = {-1, 99}; 43 * // dim0 is arbitrary and dim1 is must to be 99 of each tensor.shape() in tensors_ 44 * TensorList *tl = new TensorList(shape, element_shape); 45 * std::vector<std::vector<int> > tensor_shape = std::vector<vector<int> > (2, 46 * (std::vector<int> {5, 99}, 47 * std::vector<int> {1, 99})); 48 * // tensor_shape[0] and tensor_shape[1] is not equal in dim0, but dim1 is must be equal to 99. 49 * t1->MallocTensorListData(kNumberTypeFloat, tensor_shape); 50 * t1->MallocData(); 51 * t1->... 52 * ... 53 * t1->FreeData(); 54 * t1->FreeTensorListData(); 55 * 56 * See the code for other constructors. 57 */ 58 class TensorList : public Tensor { 59 public: 60 TensorList() = default; 61 62 TensorList(std::vector<int> shape, std::vector<int> element_shape, Category category = VAR); 63 64 ~TensorList() override; 65 66 TensorList(const TensorList &other) = delete; 67 68 TensorList &operator=(const TensorList &tl) = delete; 69 set_element_shape(const std::vector<int> & shape)70 void set_element_shape(const std::vector<int> &shape) { element_shape_ = shape; } 71 element_shape()72 std::vector<int> &element_shape() { return element_shape_; } 73 set_max_elements_num(int ele_num)74 void set_max_elements_num(int ele_num) { max_elements_num_ = ele_num; } 75 max_elements_num()76 int max_elements_num() const { return max_elements_num_; } 77 78 int MallocTensorListData(TypeId dtype, const std::vector<std::vector<int> > &tensor_shape); 79 80 int MallocData(const AllocatorPtr allocator = nullptr) override; 81 82 int FreeTensorListData(); 83 84 void FreeData() override; 85 86 int CopyTensorList(const TensorList &src, bool copy_data); 87 88 int CopyTensorData(const TensorList &src); 89 90 int SetTensor(int index, const Tensor *src_tensor); 91 92 Tensor *GetTensor(int index); 93 set_tensors_data_type(TypeId type)94 void set_tensors_data_type(TypeId type) { tensors_data_type_ = type; } 95 tensors_data_type()96 TypeId tensors_data_type() const { return tensors_data_type_; } 97 tensors()98 std::vector<Tensor *> &tensors() { return tensors_; } 99 set_tensors(const std::vector<Tensor * > & tensors)100 void set_tensors(const std::vector<Tensor *> &tensors) { this->tensors_ = tensors; } 101 102 int CheckTensorListParam(); 103 104 bool IsCompatibleShape(const std::vector<int> &shape); 105 106 bool IsCompatibleShape(const Tensor *src); 107 108 STATUS Decode(const int *data); 109 110 bool IsConst() const override; 111 set_ref_count(int ref_count)112 void set_ref_count(int ref_count) override { 113 ref_count_ = ref_count; 114 for (auto tensor : tensors_) { 115 if (tensor != nullptr) { 116 tensor->set_ref_count(ref_count); 117 } 118 } 119 } 120 ResetRefCount()121 void ResetRefCount() override { 122 set_ref_count(this->init_ref_count_); 123 for (auto tensor : tensors_) { 124 if (tensor != nullptr) { 125 tensor->set_ref_count(this->init_ref_count_); 126 } 127 } 128 } 129 IncRefCount()130 void IncRefCount() override { 131 ++ref_count_; 132 for (auto tensor : tensors_) { 133 if (tensor != nullptr) { 134 tensor->IncRefCount(); 135 } 136 } 137 } 138 DecRefCount()139 void DecRefCount() override { 140 if (this->IsConst() || this->IsGraphInput()) { 141 return; 142 } 143 --ref_count_; 144 for (auto tensor : tensors_) { 145 if (tensor != nullptr) { 146 tensor->DecRefCount(); 147 } 148 } 149 } 150 set_allocator(AllocatorPtr allocator)151 void set_allocator(AllocatorPtr allocator) override { 152 allocator_ = allocator; 153 for (auto tensor : tensors_) { 154 if (tensor != nullptr) { 155 tensor->set_allocator(allocator); 156 } 157 } 158 } 159 set_own_data(bool own_data)160 void set_own_data(bool own_data) override { 161 this->own_data_ = own_data; 162 for (auto tensor : tensors_) { 163 if (tensor != nullptr) { 164 tensor->set_own_data(own_data); 165 } 166 } 167 } 168 169 protected: 170 // The following functions must be masked. data()171 void *data() const override { return nullptr; } MutableData()172 void *MutableData() override { return nullptr; } Size()173 size_t Size() const override { return 0; } 174 std::vector<Tensor *> tensors_{}; 175 TypeId tensors_data_type_ = kTypeUnknown; 176 std::vector<int> element_shape_{}; 177 int max_elements_num_ = -1; 178 }; 179 } // namespace mindspore::lite 180 #endif // MINDSPORE_LITE_SRC_TENSORLIST_H_ 181