1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #define MAX_INTEGER_DTYPE 9223372036854775807
17
18 #include "minddata/dataset/core/tensor_shape.h"
19
20 #include <limits>
21
22 #include "utils/ms_utils.h"
23 #ifndef ENABLE_ANDROID
24 #include "utils/log_adapter.h"
25 #else
26 #include "mindspore/lite/src/common/log_adapter.h"
27 #endif
28 #include "minddata/dataset/include/dataset/constants.h"
29
30 namespace mindspore {
31 namespace dataset {
32 constexpr dsize_t TensorShape::kDimUnknown;
33
multi_ok(dsize_t x,dsize_t y)34 bool multi_ok(dsize_t x, dsize_t y) {
35 dsize_t p = x * y;
36 if (x == 0) {
37 return true;
38 }
39 return p / x == y;
40 }
41
NumOfElements() const42 dsize_t TensorShape::NumOfElements() const {
43 if (!known() && strides_.size() < 1) {
44 return 0;
45 }
46 return strides_[0];
47 }
48
Print(std::ostream & out) const49 void TensorShape::Print(std::ostream &out) const {
50 if (!known() && raw_shape_.empty()) {
51 out << "<kUnknown>";
52 } else {
53 out << "<";
54 for (auto i = 0; i < this->Rank(); i++) {
55 if (raw_shape_[i] == kDimUnknown) {
56 out << "*";
57 } else {
58 out << raw_shape_[i];
59 }
60 if (i != this->Rank() - 1) {
61 out << ",";
62 }
63 }
64 out << ">";
65 }
66 }
67
// Construct a shape from a brace-enclosed dim list, e.g. TensorShape({2, 3}).
// Both backing vectors are bound to the global int allocator before
// AddListToShape fills the dims and precomputes the strides.
TensorShape::TensorShape(const std::initializer_list<dsize_t> &list)
    : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) {
  AddListToShape(list);
}
72
// Construct a shape from a std::vector of dims.
// Both backing vectors are bound to the global int allocator before
// AddListToShape fills the dims and precomputes the strides.
TensorShape::TensorShape(const std::vector<dsize_t> &list)
    : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) {
  AddListToShape(list);
}
77
// Copy constructor: replicates the dims (and thus strides) of `shape` via
// AddListToShape, then copies known_ explicitly, since AddListToShape derives
// known_ from the dim values alone and would mis-report an unknown-rank
// (empty but unknown) shape as known.
TensorShape::TensorShape(const TensorShape &shape)
    : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) {
  AddListToShape(shape.AsVector());
  known_ = shape.known_;  // override with the input shape in case of unknown-rank tensor shape.
}
83
84 #ifdef ENABLE_PYTHON
TensorShape(py::list l)85 TensorShape::TensorShape(py::list l)
86 : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) {
87 std::vector<dsize_t> list_c;
88 for (auto &i : l) {
89 if (!i.is_none()) {
90 list_c.push_back(i.cast<int>());
91 } else {
92 list_c.push_back(TensorShape::kDimUnknown);
93 }
94 }
95 AddListToShape(list_c);
96 }
97 #endif
98
99 #ifndef ENABLE_ANDROID
// Construct a shape from an OpenCV MatSize plus the cv type word.
// The channel count is encoded in the upper bits of `type` (shifted by
// CV_CN_SHIFT); a trailing channel dim is appended only for multi-channel
// images, so a single-channel HxW image stays rank 2.
// NOTE(review): strides_ is left empty here (unlike the AddListToShape-based
// constructors), so stride-dependent members must not be used on such a
// shape until strides are rebuilt — confirm with callers.
TensorShape::TensorShape(cv::MatSize cv_size, uint32_t type)
    : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) {
  for (int i = 0; i < cv_size.dims(); i++) {
    raw_shape_.push_back(cv_size[i]);
  }
  auto channels = static_cast<uint8_t>(1 + (type >> static_cast<uint8_t>(CV_CN_SHIFT)));
  if (channels != 1) {
    raw_shape_.push_back(channels);
  }
  known_ = true;
}
111 #endif
112
CreateUnknownRankShape()113 TensorShape TensorShape::CreateUnknownRankShape() {
114 TensorShape s({});
115 s.known_ = false;
116 return s;
117 }
118
InsertDim(dsize_t axis,dsize_t dim) const119 TensorShape TensorShape::InsertDim(dsize_t axis, dsize_t dim) const {
120 std::vector<dsize_t> tmp = AsVector();
121 (void)tmp.insert(tmp.begin() + axis, dim);
122 return TensorShape(tmp);
123 }
124
AsVector() const125 std::vector<dsize_t> TensorShape::AsVector() const {
126 return std::vector<dsize_t>(raw_shape_.begin(), raw_shape_.end());
127 }
128
IsValidIndex(const std::vector<dsize_t> & index) const129 bool TensorShape::IsValidIndex(const std::vector<dsize_t> &index) const {
130 dsize_t s_rank = Rank();
131 if (index.size() != s_rank) {
132 return false;
133 }
134 for (dsize_t i = 0; i < s_rank; i++) {
135 if (index[i] < 0 || raw_shape_[i] <= index[i]) {
136 return false;
137 }
138 }
139 return true;
140 }
141
// Replaces the current shape with `list` (any reversible container of dsize_t)
// and precomputes strides_ in the same reverse pass:
//   strides_[i] = product of dims [i, rank), with strides_[rank] == 1,
// so strides_[0] ends up holding the total element count.
// On any invalid input (stride overflow, dim > kDeMaxDim, rank > kDeMaxRank)
// the shape is marked unknown and raw_shape_ is cleared before returning.
template <typename T>
void TensorShape::AddListToShape(const T &list) {
  raw_shape_.resize(list.size());
  strides_.resize(list.size() + 1);
  strides_[list.size()] = 1;
  known_ = true;
  dsize_t size = 0;
  auto itr = std::rbegin(list);  // iterate over the list in reverse order
  auto s = list.size() - 1;      // to compute strides while adding dims
  for (; itr != std::rend(list); itr++, s--) {
    dsize_t dim = *itr;
    if (dim > 0) {
      // Detect int64 overflow before multiplying into the running stride.
      if (strides_[s + 1] > std::numeric_limits<int64_t>::max() / dim) {
        MS_LOG(ERROR) << "Invalid shape data, overflow occurred!";
        known_ = false;
        raw_shape_.clear();
        return;
      }
      strides_[s] = dim * strides_[s + 1];
    }
    // A negative dim (kDimUnknown) marks the whole shape unknown; its stride
    // entry keeps whatever resize() inserted (presumably value-initialized 0
    // — confirm against the allocator's construct behavior).
    if (dim < 0) {
      known_ = false;
    }
    if (dim > kDeMaxDim) {
      std::stringstream ss;
      ss << "Invalid shape data, dim (" << dim << ") is larger than the maximum dim size(" << kDeMaxDim << ")!";
      MS_LOG(ERROR) << ss.str().c_str();
      known_ = false;
      raw_shape_.clear();
      return;
    }
    raw_shape_[s] = dim;
    size++;
  }
  // Rank check happens after the pass; on failure the shape is invalidated
  // the same way as the per-dim failures above.
  if (size > kDeMaxRank) {
    std::stringstream ss;
    ss << "Invalid shape data, rank (" << size << ") is larger than the maximum rank size(" << kDeMaxRank << ").";
    MS_LOG(ERROR) << ss.str().c_str();
    known_ = false;
    raw_shape_.clear();
    return;
  }
}
185
CreateUnknownShapeWithRank(dsize_t rank)186 TensorShape TensorShape::CreateUnknownShapeWithRank(dsize_t rank) {
187 TensorShape s({});
188 for (dsize_t i = 0; i < rank; i++) {
189 s.raw_shape_.push_back(kDimUnknown);
190 }
191 s.known_ = false;
192 return s;
193 }
194
PrependDim(dsize_t dim) const195 TensorShape TensorShape::PrependDim(dsize_t dim) const {
196 if (Size() == 0) {
197 return TensorShape({dim});
198 }
199 return InsertDim(0, dim);
200 }
201
AppendDim(dsize_t dim) const202 TensorShape TensorShape::AppendDim(dsize_t dim) const {
203 auto vec = AsVector();
204 vec.push_back(dim);
205 return TensorShape(vec);
206 }
207
208 #ifdef ENABLE_PYTHON
AsPyList()209 py::list TensorShape::AsPyList() {
210 py::list list;
211 for (auto i : raw_shape_) {
212 list.append(i);
213 }
214 return list;
215 }
216 #endif
217
Squeeze() const218 TensorShape TensorShape::Squeeze() const {
219 std::vector<dsize_t> new_shape(raw_shape_.size());
220 auto it = std::copy_if(raw_shape_.begin(), raw_shape_.end(), new_shape.begin(), [](auto s) { return s != 1; });
221 new_shape.resize(std::distance(new_shape.begin(), it));
222 return TensorShape(new_shape);
223 }
224
Strides() const225 std::vector<dsize_t> TensorShape::Strides() const { return std::vector<dsize_t>{strides_.begin() + 1, strides_.end()}; }
226
227 // Name: ToFlatIndex()
228 // Description: convert a vector style index to number, used to access memory internal use only
ToFlatIndex(const std::vector<dsize_t> & index,dsize_t * flat_index) const229 Status TensorShape::ToFlatIndex(const std::vector<dsize_t> &index, dsize_t *flat_index) const {
230 RETURN_UNEXPECTED_IF_NULL(flat_index);
231 if (index.size() != raw_shape_.size()) {
232 std::stringstream ss;
233 ss << "Index size (" << index.size() << ") does not match the shape size (" << raw_shape_.size() << ").";
234 return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, ss.str());
235 }
236 *flat_index = 0;
237 for (size_t k = 0; k < index.size(); k++) {
238 *flat_index +=
239 (index[k] == 0) ? 0 : index[k] * strides_[k + 1]; // skip the first element of strides_ which is numOfElements
240 }
241 CHECK_FAIL_RETURN_UNEXPECTED(*flat_index < NumOfElements(), "Not a valid index");
242 return Status::OK();
243 }
244 } // namespace dataset
245 } // namespace mindspore
246