• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_TENSORLIST_H_
18 #define MINDSPORE_LITE_SRC_TENSORLIST_H_
19 
20 #include <memory>
21 #include <vector>
22 #include "include/ms_tensor.h"
23 #include "include/errorcode.h"
24 #include "src/common/log_adapter.h"
25 #include "schema/model_generated.h"
26 #include "src/tensor.h"
27 
28 namespace mindspore::lite {
/**
 * TensorList is a tensor that acts as a container: each of its elements is itself a tensor object.
 * Member objects:
 *  1.tensors_: a vector in which each element is a pointer to a Tensor.
 *  2.shape_: holds the number of tensors in tensors_; shape_.size() must be equal to 1.
 *  3.element_shape_: the shape of each tensor in tensors_.
 *    Some dimensions can be negative, which means that the corresponding dimension may differ between the
 *    tensors in tensors_.
 *  4.data_type_: marks the tensorlist as a tensor of type kObjectTypeTensorType, so it can only be
 *    "kObjectTypeTensorType".
 *  5.tensors_data_type_: the data_type_ of each tensor in tensors_.
 * Usage:
 *  std::vector<int> shape(1, 2);  // i.e. {2}: tensors_ holds two tensors
 *  std::vector<int> element_shape = {-1, 99};
 *  // dim0 is arbitrary; dim1 of every tensor's shape() in tensors_ must be 99.
 *  TensorList *tl = new TensorList(shape, element_shape);
 *  std::vector<std::vector<int> > tensor_shape = {std::vector<int>{5, 99}, std::vector<int>{1, 99}};
 *  // tensor_shape[0] and tensor_shape[1] differ in dim0, but dim1 must be 99 in both.
 *  tl->MallocTensorListData(kNumberTypeFloat, tensor_shape);
 *  tl->MallocData();
 *  tl->...
 *  ...
 *  tl->FreeData();
 *  tl->FreeTensorListData();
 *
 *  See the code for other constructors.
 */
58 class TensorList : public Tensor {
59  public:
60   TensorList() = default;
61 
62   TensorList(std::vector<int> shape, std::vector<int> element_shape, Category category = VAR);
63 
64   ~TensorList() override;
65 
66   TensorList(const TensorList &other) = delete;
67 
68   TensorList &operator=(const TensorList &tl) = delete;
69 
set_element_shape(const std::vector<int> & shape)70   void set_element_shape(const std::vector<int> &shape) { element_shape_ = shape; }
71 
element_shape()72   std::vector<int> &element_shape() { return element_shape_; }
73 
set_max_elements_num(int ele_num)74   void set_max_elements_num(int ele_num) { max_elements_num_ = ele_num; }
75 
max_elements_num()76   int max_elements_num() const { return max_elements_num_; }
77 
78   int MallocTensorListData(TypeId dtype, const std::vector<std::vector<int> > &tensor_shape);
79 
80   int MallocData(const AllocatorPtr allocator = nullptr) override;
81 
82   int FreeTensorListData();
83 
84   void FreeData() override;
85 
86   int CopyTensorList(const TensorList &src, bool copy_data);
87 
88   int CopyTensorData(const TensorList &src);
89 
90   int SetTensor(int index, const Tensor *src_tensor);
91 
92   Tensor *GetTensor(int index);
93 
set_tensors_data_type(TypeId type)94   void set_tensors_data_type(TypeId type) { tensors_data_type_ = type; }
95 
tensors_data_type()96   TypeId tensors_data_type() const { return tensors_data_type_; }
97 
tensors()98   std::vector<Tensor *> &tensors() { return tensors_; }
99 
set_tensors(const std::vector<Tensor * > & tensors)100   void set_tensors(const std::vector<Tensor *> &tensors) { this->tensors_ = tensors; }
101 
102   int CheckTensorListParam();
103 
104   bool IsCompatibleShape(const std::vector<int> &shape);
105 
106   bool IsCompatibleShape(const Tensor *src);
107 
108   STATUS Decode(const int *data);
109 
110   bool IsConst() const override;
111 
set_ref_count(int ref_count)112   void set_ref_count(int ref_count) override {
113     ref_count_ = ref_count;
114     for (auto tensor : tensors_) {
115       if (tensor != nullptr) {
116         tensor->set_ref_count(ref_count);
117       }
118     }
119   }
120 
ResetRefCount()121   void ResetRefCount() override {
122     set_ref_count(this->init_ref_count_);
123     for (auto tensor : tensors_) {
124       if (tensor != nullptr) {
125         tensor->set_ref_count(this->init_ref_count_);
126       }
127     }
128   }
129 
IncRefCount()130   void IncRefCount() override {
131     ++ref_count_;
132     for (auto tensor : tensors_) {
133       if (tensor != nullptr) {
134         tensor->IncRefCount();
135       }
136     }
137   }
138 
DecRefCount()139   void DecRefCount() override {
140     if (this->IsConst() || this->IsGraphInput()) {
141       return;
142     }
143     --ref_count_;
144     for (auto tensor : tensors_) {
145       if (tensor != nullptr) {
146         tensor->DecRefCount();
147       }
148     }
149   }
150 
set_allocator(AllocatorPtr allocator)151   void set_allocator(AllocatorPtr allocator) override {
152     allocator_ = allocator;
153     for (auto tensor : tensors_) {
154       if (tensor != nullptr) {
155         tensor->set_allocator(allocator);
156       }
157     }
158   }
159 
set_own_data(bool own_data)160   void set_own_data(bool own_data) override {
161     this->own_data_ = own_data;
162     for (auto tensor : tensors_) {
163       if (tensor != nullptr) {
164         tensor->set_own_data(own_data);
165       }
166     }
167   }
168 
169  protected:
170   // The following functions must be masked.
data()171   void *data() const override { return nullptr; }
MutableData()172   void *MutableData() override { return nullptr; }
Size()173   size_t Size() const override { return 0; }
174   std::vector<Tensor *> tensors_{};
175   TypeId tensors_data_type_ = kTypeUnknown;
176   std::vector<int> element_shape_{};
177   int max_elements_num_ = -1;
178 };
179 }  // namespace mindspore::lite
180 #endif  // MINDSPORE_LITE_SRC_TENSORLIST_H_
181