1 /** 2 * Copyright 2020-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_ROW_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_ROW_H_ 19 20 #include <deque> 21 #include <memory> 22 #include <string> 23 #include <vector> 24 25 #include "minddata/dataset/core/tensor.h" 26 27 namespace mindspore { 28 namespace dataset { 29 30 class TensorRow; // A set of Tensor pointers with an id 31 using TensorTable = std::vector<TensorRow>; // The table of tensors is a vector of rows 32 using TensorQTable = std::deque<TensorRow>; // A different flavour of tensor table, this one has queue functionality 33 34 class TensorRow { 35 public: 36 static constexpr row_id_type kDefaultRowId = -1; // Default row id 37 38 enum TensorRowFlags : uint32_t { 39 kFlagNone = 0, 40 kFlagEOE = 1U, // The row is an eoe end-of-epoch msg 41 kFlagEOF = 1U << 1, // The row is an eof end-of-data msg 42 kFlagWait = 1U << 2, // The row is an control signal for workers to suspend operations 43 kFlagQuit = 1U << 3, // The row is a control signal for workers to quit 44 kFlagSkip = 1U << 4, // The row is a control signal for workers to skip this row 45 kFlagError = 1U << 5 // The row is an error row (needs to be replaced with another row or skipped, as per 46 // ErrorSamplesMode config) 47 }; 48 49 // Type definitions 50 using size_type = size_t; 51 using value_type = std::shared_ptr<Tensor>; 52 using reference = std::shared_ptr<Tensor> &; 53 using const_reference = const std::shared_ptr<Tensor> &; 54 using vector_type = std::vector<std::shared_ptr<Tensor>>; 55 using iterator = std::vector<std::shared_ptr<Tensor>>::iterator; 56 using const_iterator = std::vector<std::shared_ptr<Tensor>>::const_iterator; 57 58 TensorRow() noexcept; 59 60 TensorRow(size_type n, const value_type &t) noexcept; 61 62 // Copy Constructors 63 explicit TensorRow(const vector_type &v); 64 65 TensorRow(row_id_type id, const std::initializer_list<value_type> &lst); 66 67 TensorRow(const TensorRow &tr); 68 69 TensorRow &operator=(const TensorRow &tr); 70 71 TensorRow &operator=(const std::initializer_list<value_type> &lst); 72 73 // Move Constructors 74 explicit TensorRow(vector_type &&v) noexcept; 75 76 TensorRow(row_id_type id, std::initializer_list<value_type> &&lst) noexcept; 77 78 TensorRow(TensorRow &&tr) noexcept; 79 80 TensorRow &operator=(TensorRow &&tr) noexcept; 81 82 TensorRow &operator=(std::initializer_list<value_type> &&lst) noexcept; 83 84 // Destructor 85 ~TensorRow() = default; 86 87 // Deep copy 88 /// \param[in] new_tr the tensor row to clone to 89 Status Clone(TensorRow *new_tr) const; 90 91 /// Convert a vector of primitive types to a TensorRow consisting of one 1-D Tensor with the shape n. 92 /// \tparam `T` 93 /// \param[in] o input vector 94 /// \param[out] output TensorRow 95 template <typename T> ConvertToTensorRow(const std::vector<T> & o,TensorRow * output)96 static Status ConvertToTensorRow(const std::vector<T> &o, TensorRow *output) { 97 RETURN_UNEXPECTED_IF_NULL(output); 98 DataType data_type = DataType::FromCType<T>(); 99 if (data_type == DataType::DE_UNKNOWN) { 100 RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type was not recognized."); 101 } 102 if (data_type.IsString()) { 103 RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string and bytes are not supported."); 104 } 105 std::shared_ptr<Tensor> tensor; 106 RETURN_IF_NOT_OK(Tensor::CreateFromVector(o, &tensor)); 107 output->push_back(tensor); 108 return Status::OK(); 109 } 110 111 /// Convert a single primitive type to a TensorRow consisting of one single data Tensor. 112 /// \tparam `T` 113 /// \param[in] o input 114 /// \param[out] output TensorRow 115 template <typename T> ConvertToTensorRow(const T & o,TensorRow * output)116 static Status ConvertToTensorRow(const T &o, TensorRow *output) { 117 RETURN_UNEXPECTED_IF_NULL(output); 118 DataType data_type = DataType::FromCType<T>(); 119 if (data_type == DataType::DE_UNKNOWN) { 120 RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type was not recognized."); 121 } 122 if (data_type.IsString()) { 123 RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string and bytes are not supported."); 124 } 125 std::shared_ptr<Tensor> tensor; 126 RETURN_IF_NOT_OK(Tensor::CreateScalar(o, &tensor)); 127 output->push_back(tensor); 128 return Status::OK(); 129 } 130 131 /// Return the value in a TensorRow consisting of 1 single data Tensor. 132 /// \tparam `T` 133 /// \param[in] input TensorRow 134 /// \param[out] o the primitive variable 135 template <typename T> ConvertFromTensorRow(const TensorRow & input,T * o)136 static Status ConvertFromTensorRow(const TensorRow &input, T *o) { 137 RETURN_UNEXPECTED_IF_NULL(o); 138 DataType data_type = DataType::FromCType<T>(); 139 RETURN_IF_NOT_OK(ValidateTensorRow(input, data_type)); 140 if (input.at(0)->type() != data_type) { 141 RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The output type doesn't match the input tensor type."); 142 } 143 if (input.at(0)->shape() != TensorShape({})) { 144 RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input tensors must be a scalar tensor."); 145 } 146 return input.at(0)->GetItemAt(o, {}); 147 } 148 149 /// Convert a TensorRow consisting of one 1-D tensor to a vector of size n. 150 /// \tparam `T` 151 /// \param[in] o TensorRow consisting of one 1-D tensor 152 /// \param[out] o vector of primitive variable 153 template <typename T> ConvertFromTensorRow(const TensorRow & input,std::vector<T> * o)154 static Status ConvertFromTensorRow(const TensorRow &input, std::vector<T> *o) { 155 RETURN_UNEXPECTED_IF_NULL(o); 156 DataType data_type = DataType::FromCType<T>(); 157 RETURN_IF_NOT_OK(ValidateTensorRow(input, data_type)); 158 if (input.at(0)->Rank() != 1) { 159 RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input tensor must have a rank of 1."); 160 } 161 for (auto it = input.at(0)->begin<T>(); it != input.at(0)->end<T>(); it++) { 162 o->push_back(*it); 163 } 164 return Status::OK(); 165 } 166 167 // Functions to fetch/set id/vector getId()168 row_id_type getId() const { return id_; } 169 setId(row_id_type id)170 void setId(row_id_type id) { id_ = id; } 171 getPath()172 std::vector<std::string> getPath() const { return path_; } 173 setPath(const std::vector<std::string> & path)174 void setPath(const std::vector<std::string> &path) { path_ = path; } 175 getRow()176 const vector_type &getRow() const { return row_; } 177 SizeInBytes()178 dsize_t SizeInBytes() const { 179 dsize_t sz = 0; 180 for (auto &it : row_) { 181 sz += it->SizeInBytes(); 182 } 183 return sz; 184 } 185 186 // Wrapper functions to support vector operations emplace_back(value_type t)187 void emplace_back(value_type t) { row_.emplace_back(t); } 188 push_back(value_type t)189 void push_back(value_type t) { row_.push_back(t); } 190 clear()191 void clear() noexcept { row_.clear(); } 192 193 // Reset both the tensor row vector and flags reset()194 void reset() noexcept { 195 row_.clear(); 196 tensor_row_flag_ = kFlagNone; 197 } 198 size()199 size_type size() const noexcept { return row_.size(); } 200 reserve(size_type size)201 void reserve(size_type size) { row_.reserve(size); } 202 resize(size_type size)203 void resize(size_type size) { row_.resize(size); } 204 empty()205 const bool empty() { return row_.empty(); } 206 insert(const_iterator position,const_iterator first,const_iterator last)207 void insert(const_iterator position, const_iterator first, const_iterator last) { 208 row_.insert(position, first, last); 209 } 210 211 // Wrapper functions to support vector element access at(size_type index)212 reference at(size_type index) { return row_.at(index); } 213 at(size_type index)214 const_reference at(size_type index) const { return row_.at(index); } 215 front()216 reference front() { return row_.front(); } 217 front()218 const_reference front() const { return row_.front(); } 219 back()220 reference back() { return row_.back(); } 221 back()222 const_reference back() const { return row_.back(); } 223 224 reference operator[](size_type index) { return row_[index]; } 225 226 const_reference operator[](size_type index) const { return row_[index]; } 227 228 // Wrapper functions to support vector iteration begin()229 iterator begin() { return row_.begin(); } 230 begin()231 const_iterator begin() const { return row_.begin(); } 232 end()233 iterator end() { return row_.end(); } 234 end()235 const_iterator end() const { return row_.end(); } 236 237 // Convenience getter functions for flag checking eof()238 bool eof() const { return (static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagEOF)); } 239 eoe()240 bool eoe() const { return (static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagEOE)); } 241 wait()242 bool wait() const { return (static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagWait)); } 243 quit()244 bool quit() const { return (static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagQuit)); } 245 skip()246 bool skip() const { 247 return static_cast<bool>(static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagSkip)); 248 } 249 error()250 bool error() const { 251 return static_cast<bool>(static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagError)); 252 } 253 Flags()254 TensorRowFlags Flags() const { return tensor_row_flag_; } 255 FlagName()256 std::string FlagName() const { 257 switch (tensor_row_flag_) { 258 case TensorRowFlags::kFlagNone: 259 return "Data"; 260 case TensorRowFlags::kFlagEOE: 261 return "EOE"; 262 case TensorRowFlags::kFlagEOF: 263 return "EOF"; 264 case TensorRowFlags::kFlagWait: 265 return "Wait"; 266 case TensorRowFlags::kFlagQuit: 267 return "Quit"; 268 case TensorRowFlags::kFlagSkip: 269 return "Skip"; 270 case TensorRowFlags::kFlagError: 271 return "Error"; 272 default: 273 return "Unknown"; 274 } 275 } 276 277 explicit TensorRow(TensorRowFlags flag); 278 279 protected: 280 row_id_type id_; 281 std::vector<std::string> path_; 282 std::vector<std::shared_ptr<Tensor>> row_; 283 284 TensorRowFlags tensor_row_flag_; 285 286 private: 287 /// Validate data type of TensorRow for conversions. 288 /// \param[in] input TensorRow 289 /// \param[in] data_type data type of the tensor row 290 static Status ValidateTensorRow(const TensorRow &input, const DataType &data_type); 291 }; 292 } // namespace dataset 293 } // namespace mindspore 294 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_ROW_H_ 295