• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_TENSORLIST_H_
18 #define MINDSPORE_LITE_SRC_TENSORLIST_H_
19 
20 #include <memory>
21 #include <vector>
22 #include "include/errorcode.h"
23 #include "nnacl/tensorlist_c.h"
24 #include "src/common/log_adapter.h"
25 #include "schema/model_generated.h"
26 #include "src/tensor.h"
27 
28 namespace mindspore::lite {
29 #ifndef CONTROLFLOW_TENSORLIST_CLIP
30 /**
31  * Tensorlist is a container of vector, in which each element is a tensor object.
32  * Member objects:
33  *  1.tensors_: tensors_ is a vector, where each element is a pointer to tensor type.
34  *  2.shape_: represents the size of the tensors_ and shape_.size() must be equal to 1.
35  *  3.element_shape_: element_shape_ represents the shape of each tensor in tensors_.
36  *    Some dimensions can be negative, which means that the corresponding dimensions of each tensor in tensors_ can be
37  *    different.
38  *  4.data_type_: indicates that the tensorlist is a tensor of type kObjectTypeTensorType, so it can only be
39  *    "kObjectTypeTensorType"
40  *  5.tensors_data_type_: data_type_ of each tensor in tensors_
41  * Usage:
42  *  std::vector<int> shape = (1, 2);  // tensors_ only has two tensor
43  *  std::vector<int> element_shape = {-1, 99};
44  *  // dim0 is arbitrary and dim1 is must to be 99 of each tensor.shape() in tensors_
45  *  TensorList *tl = new TensorList(shape, element_shape);
46  *  std::vector<std::vector<int> > tensor_shape = std::vector<vector<int> > (2,
47  *                                                                          (std::vector<int> {5, 99},
48  *                                                                           std::vector<int> {1, 99}));
49  *  // tensor_shape[0] and tensor_shape[1] is not equal in dim0, but dim1 is must be equal to 99.
50  *  t1->MallocTensorListData(kNumberTypeFloat, tensor_shape);
51  *  t1->MallocData();
52  *  t1->...
53  *  ...
54  *  t1->FreeData();
55  *  t1->FreeTensorListData();
56  *
57  *  See the code for other constructors.
58  */
59 class MS_API TensorList : public Tensor {
60  public:
TensorList()61   TensorList() { tensor_list_c_ = {false, kObjectTypeTensorType, DEFAULT_FORMAT, 0, kTypeUnknown, -1, nullptr, 0, 0}; }
62 
63   TensorList(std::vector<int> shape, std::vector<int> element_shape, Category category = VAR);
64 
65   ~TensorList() override;
66 
67   TensorList(const TensorList &other) = delete;
68 
69   TensorList &operator=(const TensorList &tl) = delete;
70 
set_element_shape(const std::vector<int> & shape)71   void set_element_shape(const std::vector<int> &shape) {
72     if (shape.size() > MAX_SHAPE_SIZE) {
73       FreeData();
74       tensor_list_c_.element_shape_size_ = 0;
75       MS_LOG(WARNING) << "The shape-size has exceeded the limit 8, now is " << shape.size();
76       return;
77     }
78     tensor_list_c_.element_shape_size_ = shape.size();
79     for (size_t i = 0; i < shape.size(); ++i) {
80       tensor_list_c_.element_shape_[i] = shape[i];
81     }
82   }
83 
element_shape()84   std::vector<int> element_shape() const {
85     return std::vector<int>(tensor_list_c_.element_shape_,
86                             tensor_list_c_.element_shape_ + tensor_list_c_.element_shape_size_);
87   }
88 
set_max_elements_num(int ele_num)89   void set_max_elements_num(int ele_num) { tensor_list_c_.max_elements_num_ = ele_num; }
90 
max_elements_num()91   int max_elements_num() const { return tensor_list_c_.max_elements_num_; }
92 
93   static TensorList *CopyTensorList(const TensorList &src, bool copy_data = false,
94                                     const AllocatorPtr &allocator = nullptr);
95 
96   int MallocTensorListData(TypeId dtype, const std::vector<std::vector<int> > &tensor_shape);
97 
98   int MallocData(const AllocatorPtr allocator = nullptr) override;
99 
100   int FreeTensorListData();
101 
102   void FreeData() override;
103 
104   int SetTensor(int index, const Tensor *src_tensor);
105 
106   Tensor *GetTensor(int index);
107 
set_tensors_data_type(TypeId type)108   void set_tensors_data_type(TypeId type) { tensor_list_c_.tensors_data_type_ = type; }
109 
tensors_data_type()110   TypeId tensors_data_type() const { return static_cast<TypeId>(tensor_list_c_.tensors_data_type_); }
111 
tensors()112   std::vector<Tensor *> tensors() { return tensors_; }
113 
set_tensors(const std::vector<Tensor * > & tensors)114   void set_tensors(const std::vector<Tensor *> &tensors) { this->tensors_ = tensors; }
115 
116   int CheckTensorListParam();
117 
118   bool IsCompatibleShape(const std::vector<int> &shape);
119 
120   bool IsCompatibleShape(const Tensor *src);
121 
122   STATUS Decode(const int *data, size_t length);
123 
124   bool IsConst() const override;
125 
set_init_ref_count(int ref_count)126   void set_init_ref_count(int ref_count) override {
127     this->init_ref_count_ = ref_count;
128     for (auto tensor : tensors_) {
129       if (tensor != nullptr) {
130         tensor->set_init_ref_count(ref_count);
131       }
132     }
133   }
134 
set_ref_count(int ref_count)135   void set_ref_count(int ref_count) override {
136     ref_count_ = ref_count;
137     for (auto tensor : tensors_) {
138       if (tensor != nullptr) {
139         tensor->set_ref_count(ref_count);
140       }
141     }
142   }
143 
ResetRefCount()144   void ResetRefCount() override {
145     set_ref_count(this->init_ref_count_);
146     for (auto tensor : tensors_) {
147       if (tensor != nullptr) {
148         tensor->set_ref_count(this->init_ref_count_);
149       }
150     }
151   }
152 
IncRefCount()153   void IncRefCount() override {
154     ++ref_count_;
155     for (auto tensor : tensors_) {
156       if (tensor != nullptr) {
157         tensor->IncRefCount();
158       }
159     }
160   }
161 
DecRefCount()162   void DecRefCount() override {
163     if (this->IsConst() || this->IsGraphInput()) {
164       return;
165     }
166     --ref_count_;
167     for (auto tensor : tensors_) {
168       if (tensor != nullptr) {
169         tensor->DecRefCount();
170       }
171     }
172   }
173 
set_allocator(AllocatorPtr allocator)174   void set_allocator(AllocatorPtr allocator) override {
175     allocator_ = allocator;
176     for (auto tensor : tensors_) {
177       if (tensor != nullptr) {
178         tensor->set_allocator(allocator);
179       }
180     }
181   }
182 
set_own_data(bool own_data)183   void set_own_data(bool own_data) override {
184     this->own_data_ = own_data;
185     for (auto tensor : tensors_) {
186       if (tensor != nullptr) {
187         tensor->set_own_data(own_data);
188       }
189     }
190   }
191 
ConvertToTensorListC()192   TensorListC *ConvertToTensorListC() {
193     tensor_list_c_.format_ = tensor_c_.format_;
194     tensor_list_c_.shape_value_ = tensor_c_.shape_size_ == 0 ? 0 : tensor_c_.shape_[0];
195     tensor_list_c_.element_num_ = tensor_c_.shape_size_ == 0 ? 0 : tensors_.size();
196     tensor_list_c_.tensors_ = nullptr;
197     return &tensor_list_c_;
198   }
199 
200  protected:
201   // The following functions must be masked.
data()202   void *data() const override { return nullptr; }
MutableData()203   void *MutableData() override { return nullptr; }
Size()204   size_t Size() const override { return 0; }
205   TensorListC tensor_list_c_;
206   std::vector<Tensor *> tensors_{};
207 };
208 
209 #else
210 
// This branch is taken when CONTROLFLOW_TENSORLIST_CLIP is defined, i.e. the TensorList class is compiled out.
// The name is aliased to void, presumably so code that only names the type (e.g. as a pointer) still builds —
// TODO confirm against callers.
211 using TensorList = void;
212 
213 #endif
214 }  // namespace mindspore::lite
215 #endif  // MINDSPORE_LITE_SRC_TENSORLIST_H_
216