• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #define MAX_INTEGER_DTYPE 9223372036854775807
17 
18 #include "minddata/dataset/core/tensor_shape.h"
19 
20 #include <limits>
21 
22 #include "utils/ms_utils.h"
23 #include "minddata/dataset/util/log_adapter.h"
24 #include "minddata/dataset/include/dataset/constants.h"
25 
26 namespace mindspore {
27 namespace dataset {
28 constexpr dsize_t TensorShape::kDimUnknown;
29 
multi_ok(dsize_t x,dsize_t y)30 bool multi_ok(dsize_t x, dsize_t y) {
31   dsize_t p = x * y;
32   if (x == 0) {
33     return true;
34   }
35   return p / x == y;
36 }
37 
NumOfElements() const38 dsize_t TensorShape::NumOfElements() const {
39   if (!known() && strides_.size() < 1) {
40     return 0;
41   }
42   return strides_[0];
43 }
44 
Print(std::ostream & out) const45 void TensorShape::Print(std::ostream &out) const {
46   if (!known() && raw_shape_.empty()) {
47     out << "<kUnknown>";
48   } else {
49     out << "<";
50     for (auto i = 0; i < this->Rank(); i++) {
51       if (raw_shape_[i] == kDimUnknown) {
52         out << "*";
53       } else {
54         out << raw_shape_[i];
55       }
56       if (i != this->Rank() - 1) {
57         out << ",";
58       }
59     }
60     out << ">";
61   }
62 }
63 
TensorShape(const std::initializer_list<dsize_t> & list)64 TensorShape::TensorShape(const std::initializer_list<dsize_t> &list) { AddListToShape(list); }
65 
TensorShape(const std::vector<dsize_t> & list)66 TensorShape::TensorShape(const std::vector<dsize_t> &list) { AddListToShape(list); }
67 
// Copy constructor: duplicates the dimensions, the stride table and the known flag.
TensorShape::TensorShape(const TensorShape &shape)
    : raw_shape_(shape.raw_shape_), strides_(shape.strides_), known_(shape.known_) {}

// Move constructor. The source is left as an empty, unknown shape. Moving the vectors empties them, and
// keeping known_ == true on the source would produce a "known" shape with no stride table; NumOfElements()
// would then read strides_[0] out of range.
TensorShape::TensorShape(TensorShape &&shape) noexcept
    : raw_shape_(std::move(shape.raw_shape_)), strides_(std::move(shape.strides_)), known_(shape.known_) {
  shape.known_ = false;
}

// Copy assignment: replaces this shape with a copy of `shape` (self-assignment is a no-op).
TensorShape &TensorShape::operator=(const TensorShape &shape) {
  if (this != &shape) {
    raw_shape_ = shape.raw_shape_;
    strides_ = shape.strides_;
    known_ = shape.known_;
  }
  return *this;
}

// Move assignment: takes over the contents of `shape` and leaves it in the same empty, unknown state as the
// move constructor does.
TensorShape &TensorShape::operator=(TensorShape &&shape) noexcept {
  if (this != &shape) {
    raw_shape_ = std::move(shape.raw_shape_);
    strides_ = std::move(shape.strides_);
    known_ = shape.known_;
    // Vector move-assignment does not guarantee an empty source, so clear explicitly (clear() is noexcept).
    shape.raw_shape_.clear();
    shape.strides_.clear();
    shape.known_ = false;
  }
  return *this;
}
91 
92 #ifdef ENABLE_PYTHON
TensorShape(py::list l)93 TensorShape::TensorShape(py::list l) {
94   std::vector<dsize_t> list_c;
95   for (auto &i : l) {
96     if (!i.is_none()) {
97       list_c.push_back(i.cast<int>());
98     } else {
99       list_c.push_back(TensorShape::kDimUnknown);
100     }
101   }
102   AddListToShape(list_c);
103 }
104 #endif
105 
106 #ifndef ENABLE_ANDROID
TensorShape(cv::MatSize cv_size,uint32_t type)107 TensorShape::TensorShape(cv::MatSize cv_size, uint32_t type) : known_(true) {
108   for (int i = 0; i < cv_size.dims(); i++) {
109     raw_shape_.push_back(cv_size[i]);
110   }
111   auto channels = static_cast<uint8_t>(1 + (type >> static_cast<uint8_t>(CV_CN_SHIFT)));
112   if (channels != 1) {
113     raw_shape_.push_back(channels);
114   }
115 }
116 #endif
117 
CreateUnknownRankShape()118 TensorShape TensorShape::CreateUnknownRankShape() {
119   TensorShape s({});
120   s.known_ = false;
121   return s;
122 }
123 
InsertDim(dsize_t axis,dsize_t dim) const124 TensorShape TensorShape::InsertDim(dsize_t axis, dsize_t dim) const {
125   std::vector<dsize_t> tmp = AsVector();
126   (void)tmp.insert(tmp.begin() + axis, dim);
127   return TensorShape(tmp);
128 }
129 
AsVector() const130 std::vector<dsize_t> TensorShape::AsVector() const {
131   return std::vector<dsize_t>(raw_shape_.begin(), raw_shape_.end());
132 }
133 
IsValidIndex(const std::vector<dsize_t> & index) const134 bool TensorShape::IsValidIndex(const std::vector<dsize_t> &index) const {
135   dsize_t s_rank = Rank();
136   if (index.size() != s_rank) {
137     return false;
138   }
139   for (dsize_t i = 0; i < s_rank; i++) {
140     if (index[i] < 0 || raw_shape_[i] <= index[i]) {
141       return false;
142     }
143   }
144   return true;
145 }
146 
147 template <typename T>
AddListToShape(const T & list)148 void TensorShape::AddListToShape(const T &list) {
149   raw_shape_.resize(list.size());
150   strides_.resize(list.size() + 1);
151   strides_[list.size()] = 1;
152   known_ = true;
153   dsize_t size = 0;
154   auto itr = std::rbegin(list);  // iterate over the list in reverse order
155   auto s = list.size() - 1;      // to compute strides while adding dims
156   for (; itr != std::rend(list); itr++, s--) {
157     dsize_t dim = *itr;
158     if (dim > 0) {
159       if (strides_[s + 1] > std::numeric_limits<int64_t>::max() / dim) {
160         MS_LOG(ERROR) << "Invalid shape data, overflow occurred!";
161         known_ = false;
162         raw_shape_.clear();
163         return;
164       }
165       strides_[s] = dim * strides_[s + 1];
166     }
167     if (dim < 0) {
168       known_ = false;
169     }
170     if (dim > kDeMaxDim) {
171       std::stringstream ss;
172       ss << "Invalid shape data, dim (" << dim << ") is larger than the maximum dim size(" << kDeMaxDim << ")!";
173       MS_LOG(ERROR) << ss.str().c_str();
174       known_ = false;
175       raw_shape_.clear();
176       return;
177     }
178     raw_shape_[s] = dim;
179     size++;
180   }
181   if (size > kDeMaxRank) {
182     std::stringstream ss;
183     ss << "Invalid shape data, rank (" << size << ") is larger than the maximum rank size(" << kDeMaxRank << ").";
184     MS_LOG(ERROR) << ss.str().c_str();
185     known_ = false;
186     raw_shape_.clear();
187     return;
188   }
189 }
190 
CreateUnknownShapeWithRank(dsize_t rank)191 TensorShape TensorShape::CreateUnknownShapeWithRank(dsize_t rank) {
192   TensorShape s({});
193   for (dsize_t i = 0; i < rank; i++) {
194     s.raw_shape_.push_back(kDimUnknown);
195   }
196   s.known_ = false;
197   return s;
198 }
199 
PrependDim(dsize_t dim) const200 TensorShape TensorShape::PrependDim(dsize_t dim) const {
201   if (Size() == 0) {
202     return TensorShape({dim});
203   }
204   return InsertDim(0, dim);
205 }
206 
AppendDim(dsize_t dim) const207 TensorShape TensorShape::AppendDim(dsize_t dim) const {
208   auto vec = AsVector();
209   vec.push_back(dim);
210   return TensorShape(vec);
211 }
212 
213 #ifdef ENABLE_PYTHON
AsPyList()214 py::list TensorShape::AsPyList() {
215   py::list list;
216   for (auto i : raw_shape_) {
217     list.append(i);
218   }
219   return list;
220 }
221 #endif
222 
Squeeze() const223 TensorShape TensorShape::Squeeze() const {
224   std::vector<dsize_t> new_shape(raw_shape_.size());
225   auto it = std::copy_if(raw_shape_.begin(), raw_shape_.end(), new_shape.begin(), [](auto s) { return s != 1; });
226   new_shape.resize(std::distance(new_shape.begin(), it));
227   return TensorShape(new_shape);
228 }
229 
Strides() const230 std::vector<dsize_t> TensorShape::Strides() const { return std::vector<dsize_t>{strides_.begin() + 1, strides_.end()}; }
231 
232 // Name: ToFlatIndex()
233 // Description: convert a vector style index to number, used to access memory internal use only
ToFlatIndex(const std::vector<dsize_t> & index,dsize_t * flat_index) const234 Status TensorShape::ToFlatIndex(const std::vector<dsize_t> &index, dsize_t *flat_index) const {
235   RETURN_UNEXPECTED_IF_NULL(flat_index);
236   if (index.size() != raw_shape_.size()) {
237     std::stringstream ss;
238     ss << "Index size (" << index.size() << ") does not match the shape size (" << raw_shape_.size() << ").";
239     RETURN_STATUS_UNEXPECTED(ss.str());
240   }
241   *flat_index = 0;
242   for (size_t k = 0; k < index.size(); k++) {
243     *flat_index +=
244       (index[k] == 0) ? 0 : index[k] * strides_[k + 1];  // skip the first element of strides_ which is numOfElements
245   }
246   CHECK_FAIL_RETURN_UNEXPECTED(*flat_index < NumOfElements(), "Not a valid index");
247   return Status::OK();
248 }
249 }  // namespace dataset
250 }  // namespace mindspore
251