1 /**
2 * Copyright 2019-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #define MAX_INTEGER_DTYPE 9223372036854775807
17
18 #include "minddata/dataset/core/tensor_shape.h"
19
20 #include <limits>
21
22 #include "utils/ms_utils.h"
23 #include "minddata/dataset/util/log_adapter.h"
24 #include "minddata/dataset/include/dataset/constants.h"
25
26 namespace mindspore {
27 namespace dataset {
28 constexpr dsize_t TensorShape::kDimUnknown;
29
multi_ok(dsize_t x,dsize_t y)30 bool multi_ok(dsize_t x, dsize_t y) {
31 dsize_t p = x * y;
32 if (x == 0) {
33 return true;
34 }
35 return p / x == y;
36 }
37
NumOfElements() const38 dsize_t TensorShape::NumOfElements() const {
39 if (!known() && strides_.size() < 1) {
40 return 0;
41 }
42 return strides_[0];
43 }
44
Print(std::ostream & out) const45 void TensorShape::Print(std::ostream &out) const {
46 if (!known() && raw_shape_.empty()) {
47 out << "<kUnknown>";
48 } else {
49 out << "<";
50 for (auto i = 0; i < this->Rank(); i++) {
51 if (raw_shape_[i] == kDimUnknown) {
52 out << "*";
53 } else {
54 out << raw_shape_[i];
55 }
56 if (i != this->Rank() - 1) {
57 out << ",";
58 }
59 }
60 out << ">";
61 }
62 }
63
TensorShape(const std::initializer_list<dsize_t> & list)64 TensorShape::TensorShape(const std::initializer_list<dsize_t> &list) { AddListToShape(list); }
65
TensorShape(const std::vector<dsize_t> & list)66 TensorShape::TensorShape(const std::vector<dsize_t> &list) { AddListToShape(list); }
67
// Copy constructor: duplicates dims, strides and the known flag.
TensorShape::TensorShape(const TensorShape &shape)
    : raw_shape_(shape.raw_shape_),
      strides_(shape.strides_),
      known_(shape.known_) {}
70
TensorShape(TensorShape && shape)71 TensorShape::TensorShape(TensorShape &&shape) noexcept
72 : raw_shape_(std::move(shape.raw_shape_)), strides_(std::move(shape.strides_)), known_(shape.known_) {}
73
// Copy assignment with a self-assignment guard.
TensorShape &TensorShape::operator=(const TensorShape &shape) {
  if (this == &shape) {
    return *this;
  }
  raw_shape_ = shape.raw_shape_;
  strides_ = shape.strides_;
  known_ = shape.known_;
  return *this;
}
82
operator =(TensorShape && shape)83 TensorShape &TensorShape::operator=(TensorShape &&shape) noexcept {
84 if (this != &shape) {
85 raw_shape_ = std::move(shape.raw_shape_);
86 strides_ = std::move(shape.strides_);
87 known_ = shape.known_;
88 }
89 return *this;
90 }
91
#ifdef ENABLE_PYTHON
// Build a shape from a Python list; None entries become unknown dims.
TensorShape::TensorShape(py::list l) {
  std::vector<dsize_t> list_c;
  for (auto &i : l) {
    if (!i.is_none()) {
      // Cast straight to dsize_t (64-bit). The previous cast<int>() silently
      // truncated any dim larger than INT_MAX before widening back to dsize_t.
      list_c.push_back(i.cast<dsize_t>());
    } else {
      list_c.push_back(TensorShape::kDimUnknown);
    }
  }
  AddListToShape(list_c);
}
#endif
105
#ifndef ENABLE_ANDROID
// Construct a shape from an OpenCV MatSize plus the cv type code.
// The channel count is decoded from the type word (bits above CV_CN_SHIFT,
// per OpenCV's CV_MAKETYPE encoding) and, when greater than 1, appended as a
// trailing dim.
// NOTE(review): strides_ is NOT populated here (unlike shapes built via
// AddListToShape), so stride-dependent queries on such a shape rely on the
// callers never needing them — confirm.
TensorShape::TensorShape(cv::MatSize cv_size, uint32_t type) : known_(true) {
  for (int i = 0; i < cv_size.dims(); i++) {
    raw_shape_.push_back(cv_size[i]);
  }
  // channels = 1 + (type >> CV_CN_SHIFT), OpenCV's standard channel decoding
  auto channels = static_cast<uint8_t>(1 + (type >> static_cast<uint8_t>(CV_CN_SHIFT)));
  if (channels != 1) {
    raw_shape_.push_back(channels);
  }
}
#endif
117
CreateUnknownRankShape()118 TensorShape TensorShape::CreateUnknownRankShape() {
119 TensorShape s({});
120 s.known_ = false;
121 return s;
122 }
123
InsertDim(dsize_t axis,dsize_t dim) const124 TensorShape TensorShape::InsertDim(dsize_t axis, dsize_t dim) const {
125 std::vector<dsize_t> tmp = AsVector();
126 (void)tmp.insert(tmp.begin() + axis, dim);
127 return TensorShape(tmp);
128 }
129
AsVector() const130 std::vector<dsize_t> TensorShape::AsVector() const {
131 return std::vector<dsize_t>(raw_shape_.begin(), raw_shape_.end());
132 }
133
IsValidIndex(const std::vector<dsize_t> & index) const134 bool TensorShape::IsValidIndex(const std::vector<dsize_t> &index) const {
135 dsize_t s_rank = Rank();
136 if (index.size() != s_rank) {
137 return false;
138 }
139 for (dsize_t i = 0; i < s_rank; i++) {
140 if (index[i] < 0 || raw_shape_[i] <= index[i]) {
141 return false;
142 }
143 }
144 return true;
145 }
146
// Populate raw_shape_ and strides_ from any iterable of dims
// (initializer_list or vector). strides_ gets Rank() + 1 entries computed
// right-to-left: strides_[Rank()] == 1, strides_[k] == dims[k] * strides_[k+1],
// so strides_[0] ends up holding the total element count (see NumOfElements).
// On any invalid input (overflow, dim too large, rank too large) the shape is
// reset to an empty, unknown shape and an error is logged.
template <typename T>
void TensorShape::AddListToShape(const T &list) {
  raw_shape_.resize(list.size());
  strides_.resize(list.size() + 1);
  strides_[list.size()] = 1;
  known_ = true;
  dsize_t size = 0;
  auto itr = std::rbegin(list);  // iterate over the list in reverse order
  auto s = list.size() - 1;      // to compute strides while adding dims
  for (; itr != std::rend(list); itr++, s--) {
    dsize_t dim = *itr;
    if (dim > 0) {
      // Reject shapes whose element count would overflow int64; the division
      // check runs BEFORE the multiplication so no overflow ever occurs.
      if (strides_[s + 1] > std::numeric_limits<int64_t>::max() / dim) {
        MS_LOG(ERROR) << "Invalid shape data, overflow occurred!";
        known_ = false;
        raw_shape_.clear();
        return;
      }
      strides_[s] = dim * strides_[s + 1];
    }
    // A negative dim (kDimUnknown) marks the whole shape as unknown; its
    // stride entry stays at the 0 value-initialized by resize() above.
    if (dim < 0) {
      known_ = false;
    }
    if (dim > kDeMaxDim) {
      std::stringstream ss;
      ss << "Invalid shape data, dim (" << dim << ") is larger than the maximum dim size(" << kDeMaxDim << ")!";
      MS_LOG(ERROR) << ss.str().c_str();
      known_ = false;
      raw_shape_.clear();
      return;
    }
    raw_shape_[s] = dim;
    size++;  // count dims processed so the rank can be validated below
  }
  if (size > kDeMaxRank) {
    std::stringstream ss;
    ss << "Invalid shape data, rank (" << size << ") is larger than the maximum rank size(" << kDeMaxRank << ").";
    MS_LOG(ERROR) << ss.str().c_str();
    known_ = false;
    raw_shape_.clear();
    return;
  }
}
190
CreateUnknownShapeWithRank(dsize_t rank)191 TensorShape TensorShape::CreateUnknownShapeWithRank(dsize_t rank) {
192 TensorShape s({});
193 for (dsize_t i = 0; i < rank; i++) {
194 s.raw_shape_.push_back(kDimUnknown);
195 }
196 s.known_ = false;
197 return s;
198 }
199
PrependDim(dsize_t dim) const200 TensorShape TensorShape::PrependDim(dsize_t dim) const {
201 if (Size() == 0) {
202 return TensorShape({dim});
203 }
204 return InsertDim(0, dim);
205 }
206
AppendDim(dsize_t dim) const207 TensorShape TensorShape::AppendDim(dsize_t dim) const {
208 auto vec = AsVector();
209 vec.push_back(dim);
210 return TensorShape(vec);
211 }
212
#ifdef ENABLE_PYTHON
// Convert the shape to a Python list of dims.
// Unknown dims are emitted as their raw kDimUnknown value, not as None.
py::list TensorShape::AsPyList() {
  py::list list;
  for (auto i : raw_shape_) {
    list.append(i);
  }
  return list;
}
#endif
222
Squeeze() const223 TensorShape TensorShape::Squeeze() const {
224 std::vector<dsize_t> new_shape(raw_shape_.size());
225 auto it = std::copy_if(raw_shape_.begin(), raw_shape_.end(), new_shape.begin(), [](auto s) { return s != 1; });
226 new_shape.resize(std::distance(new_shape.begin(), it));
227 return TensorShape(new_shape);
228 }
229
Strides() const230 std::vector<dsize_t> TensorShape::Strides() const { return std::vector<dsize_t>{strides_.begin() + 1, strides_.end()}; }
231
232 // Name: ToFlatIndex()
233 // Description: convert a vector style index to number, used to access memory internal use only
ToFlatIndex(const std::vector<dsize_t> & index,dsize_t * flat_index) const234 Status TensorShape::ToFlatIndex(const std::vector<dsize_t> &index, dsize_t *flat_index) const {
235 RETURN_UNEXPECTED_IF_NULL(flat_index);
236 if (index.size() != raw_shape_.size()) {
237 std::stringstream ss;
238 ss << "Index size (" << index.size() << ") does not match the shape size (" << raw_shape_.size() << ").";
239 RETURN_STATUS_UNEXPECTED(ss.str());
240 }
241 *flat_index = 0;
242 for (size_t k = 0; k < index.size(); k++) {
243 *flat_index +=
244 (index[k] == 0) ? 0 : index[k] * strides_[k + 1]; // skip the first element of strides_ which is numOfElements
245 }
246 CHECK_FAIL_RETURN_UNEXPECTED(*flat_index < NumOfElements(), "Not a valid index");
247 return Status::OK();
248 }
249 } // namespace dataset
250 } // namespace mindspore
251