1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_ROW_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_ROW_H_ 19 20 #include <deque> 21 #include <memory> 22 #include <string> 23 #include <vector> 24 25 #include "minddata/dataset/core/tensor.h" 26 27 namespace mindspore { 28 namespace dataset { 29 30 class TensorRow; // A set of Tensor pointers with an id 31 using TensorTable = std::vector<TensorRow>; // The table of tensors is a vector of rows 32 using TensorQTable = std::deque<TensorRow>; // A different flavour of tensor table, this one has queue functionality 33 34 class TensorRow { 35 public: 36 static constexpr row_id_type kDefaultRowId = -1; // Default row id 37 38 enum TensorRowFlags : uint32_t { 39 kFlagNone = 0, 40 kFlagEOF = 1, // The row is an eof end-of-data msg 41 kFlagEOE = 1u << 1, // The row is an eoe end-of-epoch msg 42 kFlagWait = 1u << 2, // The row is an control signal for workers to suspend operations 43 kFlagQuit = 1u << 3 // The row is a control signal for workers to quit 44 }; 45 46 // Type definitions 47 using size_type = size_t; 48 using value_type = std::shared_ptr<Tensor>; 49 using reference = std::shared_ptr<Tensor> &; 50 using const_reference = const std::shared_ptr<Tensor> &; 51 using vector_type = std::vector<std::shared_ptr<Tensor>>; 52 using iterator = std::vector<std::shared_ptr<Tensor>>::iterator; 53 using const_iterator = std::vector<std::shared_ptr<Tensor>>::const_iterator; 54 55 TensorRow() noexcept; 56 57 TensorRow(size_type n, const value_type &t) noexcept; 58 59 // Copy Constructors 60 explicit TensorRow(const vector_type &v); 61 62 TensorRow(row_id_type id, const std::initializer_list<value_type> &lst); 63 64 TensorRow(const TensorRow &tr); 65 66 TensorRow &operator=(const TensorRow &tr); 67 68 TensorRow &operator=(const std::initializer_list<value_type> &lst); 69 70 // Move Constructors 71 explicit TensorRow(vector_type &&v) noexcept; 72 73 TensorRow(row_id_type id, std::initializer_list<value_type> &&lst) noexcept; 74 75 TensorRow(TensorRow &&tr) noexcept; 76 77 TensorRow &operator=(TensorRow &&tr) noexcept; 78 79 TensorRow &operator=(std::initializer_list<value_type> &&lst) noexcept; 80 81 // Destructor 82 ~TensorRow() = default; 83 84 /// Convert a vector of primitive types to a TensorRow consisting of one 1-D Tensor with the shape n. 85 /// \tparam `T` 86 /// \param[in] o input vector 87 /// \param[out] output TensorRow 88 template <typename T> ConvertToTensorRow(const std::vector<T> & o,TensorRow * output)89 static Status ConvertToTensorRow(const std::vector<T> &o, TensorRow *output) { 90 RETURN_UNEXPECTED_IF_NULL(output); 91 DataType data_type = DataType::FromCType<T>(); 92 if (data_type == DataType::DE_UNKNOWN) { 93 RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type was not recognized."); 94 } 95 if (data_type == DataType::DE_STRING) { 96 RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported."); 97 } 98 std::shared_ptr<Tensor> tensor; 99 RETURN_IF_NOT_OK(Tensor::CreateFromVector(o, &tensor)); 100 output->push_back(tensor); 101 return Status::OK(); 102 } 103 104 /// Convert a single primitive type to a TensorRow consisting of one single data Tensor. 105 /// \tparam `T` 106 /// \param[in] o input 107 /// \param[out] output TensorRow 108 template <typename T> ConvertToTensorRow(const T & o,TensorRow * output)109 static Status ConvertToTensorRow(const T &o, TensorRow *output) { 110 RETURN_UNEXPECTED_IF_NULL(output); 111 DataType data_type = DataType::FromCType<T>(); 112 if (data_type == DataType::DE_UNKNOWN) { 113 RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type was not recognized."); 114 } 115 if (data_type == DataType::DE_STRING) { 116 RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported."); 117 } 118 std::shared_ptr<Tensor> tensor; 119 RETURN_IF_NOT_OK(Tensor::CreateScalar(o, &tensor)); 120 output->push_back(tensor); 121 return Status::OK(); 122 } 123 124 /// Return the value in a TensorRow consisting of 1 single data Tensor. 125 /// \tparam `T` 126 /// \param[in] input TensorRow 127 /// \param[out] o the primitive variable 128 template <typename T> ConvertFromTensorRow(const TensorRow & input,T * o)129 static Status ConvertFromTensorRow(const TensorRow &input, T *o) { 130 RETURN_UNEXPECTED_IF_NULL(o); 131 DataType data_type = DataType::FromCType<T>(); 132 RETURN_IF_NOT_OK(ValidateTensorRow(input, data_type)); 133 if (input.at(0)->type() != data_type) { 134 RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The output type doesn't match the input tensor type."); 135 } 136 if (input.at(0)->shape() != TensorShape({})) { 137 RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input tensors must be a scalar tensor."); 138 } 139 return input.at(0)->GetItemAt(o, {}); 140 } 141 142 /// Convert a TensorRow consisting of one 1-D tensor to a vector of size n. 143 /// \tparam `T` 144 /// \param[in] o TensorRow consisting of one 1-D tensor 145 /// \param[out] o vector of primitive variable 146 template <typename T> ConvertFromTensorRow(const TensorRow & input,std::vector<T> * o)147 static Status ConvertFromTensorRow(const TensorRow &input, std::vector<T> *o) { 148 RETURN_UNEXPECTED_IF_NULL(o); 149 DataType data_type = DataType::FromCType<T>(); 150 RETURN_IF_NOT_OK(ValidateTensorRow(input, data_type)); 151 if (input.at(0)->Rank() != 1) 152 RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input tensor must have a rank of 1."); 153 for (auto it = input.at(0)->begin<T>(); it != input.at(0)->end<T>(); it++) { 154 o->push_back(*it); 155 } 156 return Status::OK(); 157 } 158 159 // Functions to fetch/set id/vector getId()160 row_id_type getId() const { return id_; } 161 setId(row_id_type id)162 void setId(row_id_type id) { id_ = id; } 163 getPath()164 std::vector<std::string> getPath() const { return path_; } 165 setPath(std::vector<std::string> path)166 void setPath(std::vector<std::string> path) { path_ = path; } 167 getRow()168 const vector_type &getRow() const { return row_; } 169 SizeInBytes()170 int64_t SizeInBytes() const { 171 size_t sz = 0; 172 for (auto &it : row_) { 173 sz += it->SizeInBytes(); 174 } 175 return sz; 176 } 177 178 // Wrapper functions to support vector operations emplace_back(value_type t)179 void emplace_back(value_type t) { row_.emplace_back(t); } 180 push_back(value_type t)181 void push_back(value_type t) { row_.push_back(t); } 182 clear()183 void clear() noexcept { row_.clear(); } 184 size()185 size_type size() const noexcept { return row_.size(); } 186 reserve(size_type size)187 void reserve(size_type size) { row_.reserve(size); } 188 resize(size_type size)189 void resize(size_type size) { row_.resize(size); } 190 empty()191 bool empty() { return row_.empty(); } 192 insert(iterator position,iterator first,iterator last)193 void insert(iterator position, iterator first, iterator last) { row_.insert(position, first, last); } 194 195 // Wrapper functions to support vector element access at(size_type index)196 reference at(size_type index) { return row_.at(index); } 197 at(size_type index)198 const_reference at(size_type index) const { return row_.at(index); } 199 front()200 reference front() { return row_.front(); } 201 front()202 const_reference front() const { return row_.front(); } 203 back()204 reference back() { return row_.back(); } 205 back()206 const_reference back() const { return row_.back(); } 207 208 reference operator[](size_type index) { return row_[index]; } 209 210 const_reference operator[](size_type index) const { return row_[index]; } 211 212 // Wrapper functions to support vector iteration begin()213 iterator begin() { return row_.begin(); } 214 begin()215 const_iterator begin() const { return row_.begin(); } 216 end()217 iterator end() { return row_.end(); } 218 end()219 const_iterator end() const { return row_.end(); } 220 221 // Convenience getter functions for flag checking eof()222 bool eof() const { return (static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagEOF)); } 223 eoe()224 bool eoe() const { return (static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagEOE)); } 225 wait()226 bool wait() const { return (static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagWait)); } 227 quit()228 bool quit() const { return (static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagQuit)); } 229 Flags()230 TensorRowFlags Flags() { return tensor_row_flag_; } 231 232 explicit TensorRow(TensorRowFlags); 233 234 protected: 235 row_id_type id_; 236 std::vector<std::string> path_; 237 std::vector<std::shared_ptr<Tensor>> row_; 238 239 TensorRowFlags tensor_row_flag_; 240 241 private: 242 /// Validate data type of TensorRow for conversions. 243 /// \param[in] input TensorRow 244 /// \param[in] data_type data type of the tensor row 245 static Status ValidateTensorRow(const TensorRow &input, const DataType &data_type); 246 }; 247 } // namespace dataset 248 } // namespace mindspore 249 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_ROW_H_ 250