• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_ROW_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_ROW_H_
19 
#include <cstdint>
#include <deque>
#include <initializer_list>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/core/tensor.h"
26 
27 namespace mindspore {
28 namespace dataset {
29 
// Forward declaration: a TensorRow is a set of Tensor pointers tagged with a row id.
class TensorRow;

// A table of tensors, stored as a vector of rows.
using TensorTable = std::vector<TensorRow>;

// Another flavour of tensor table, backed by a deque so it can be used as a queue.
using TensorQTable = std::deque<TensorRow>;
33 
34 class TensorRow {
35  public:
36   static constexpr row_id_type kDefaultRowId = -1;  // Default row id
37 
38   enum TensorRowFlags : uint32_t {
39     kFlagNone = 0,
40     kFlagEOE = 1U,        // The row is an eoe end-of-epoch msg
41     kFlagEOF = 1U << 1,   // The row is an eof end-of-data msg
42     kFlagWait = 1U << 2,  // The row is an control signal for workers to suspend operations
43     kFlagQuit = 1U << 3,  // The row is a control signal for workers to quit
44     kFlagSkip = 1U << 4,  // The row is a control signal for workers to skip this row
45     kFlagError = 1U << 5  // The row is an error row (needs to be replaced with another row or skipped, as per
46                           //   ErrorSamplesMode config)
47   };
48 
49   // Type definitions
50   using size_type = size_t;
51   using value_type = std::shared_ptr<Tensor>;
52   using reference = std::shared_ptr<Tensor> &;
53   using const_reference = const std::shared_ptr<Tensor> &;
54   using vector_type = std::vector<std::shared_ptr<Tensor>>;
55   using iterator = std::vector<std::shared_ptr<Tensor>>::iterator;
56   using const_iterator = std::vector<std::shared_ptr<Tensor>>::const_iterator;
57 
58   TensorRow() noexcept;
59 
60   TensorRow(size_type n, const value_type &t) noexcept;
61 
62   // Copy Constructors
63   explicit TensorRow(const vector_type &v);
64 
65   TensorRow(row_id_type id, const std::initializer_list<value_type> &lst);
66 
67   TensorRow(const TensorRow &tr);
68 
69   TensorRow &operator=(const TensorRow &tr);
70 
71   TensorRow &operator=(const std::initializer_list<value_type> &lst);
72 
73   // Move Constructors
74   explicit TensorRow(vector_type &&v) noexcept;
75 
76   TensorRow(row_id_type id, std::initializer_list<value_type> &&lst) noexcept;
77 
78   TensorRow(TensorRow &&tr) noexcept;
79 
80   TensorRow &operator=(TensorRow &&tr) noexcept;
81 
82   TensorRow &operator=(std::initializer_list<value_type> &&lst) noexcept;
83 
84   // Destructor
85   ~TensorRow() = default;
86 
87   // Deep copy
88   /// \param[in] new_tr the tensor row to clone to
89   Status Clone(TensorRow *new_tr) const;
90 
91   /// Convert a vector of primitive types to a TensorRow consisting of one 1-D Tensor with the shape n.
92   /// \tparam `T`
93   /// \param[in] o input vector
94   /// \param[out] output TensorRow
95   template <typename T>
ConvertToTensorRow(const std::vector<T> & o,TensorRow * output)96   static Status ConvertToTensorRow(const std::vector<T> &o, TensorRow *output) {
97     RETURN_UNEXPECTED_IF_NULL(output);
98     DataType data_type = DataType::FromCType<T>();
99     if (data_type == DataType::DE_UNKNOWN) {
100       RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type was not recognized.");
101     }
102     if (data_type.IsString()) {
103       RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string and bytes are not supported.");
104     }
105     std::shared_ptr<Tensor> tensor;
106     RETURN_IF_NOT_OK(Tensor::CreateFromVector(o, &tensor));
107     output->push_back(tensor);
108     return Status::OK();
109   }
110 
111   /// Convert a single primitive type to a TensorRow consisting of one single data Tensor.
112   /// \tparam `T`
113   /// \param[in] o input
114   /// \param[out] output TensorRow
115   template <typename T>
ConvertToTensorRow(const T & o,TensorRow * output)116   static Status ConvertToTensorRow(const T &o, TensorRow *output) {
117     RETURN_UNEXPECTED_IF_NULL(output);
118     DataType data_type = DataType::FromCType<T>();
119     if (data_type == DataType::DE_UNKNOWN) {
120       RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type was not recognized.");
121     }
122     if (data_type.IsString()) {
123       RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string and bytes are not supported.");
124     }
125     std::shared_ptr<Tensor> tensor;
126     RETURN_IF_NOT_OK(Tensor::CreateScalar(o, &tensor));
127     output->push_back(tensor);
128     return Status::OK();
129   }
130 
131   /// Return the value in a TensorRow consisting of 1 single data Tensor.
132   /// \tparam `T`
133   /// \param[in] input TensorRow
134   /// \param[out] o the primitive variable
135   template <typename T>
ConvertFromTensorRow(const TensorRow & input,T * o)136   static Status ConvertFromTensorRow(const TensorRow &input, T *o) {
137     RETURN_UNEXPECTED_IF_NULL(o);
138     DataType data_type = DataType::FromCType<T>();
139     RETURN_IF_NOT_OK(ValidateTensorRow(input, data_type));
140     if (input.at(0)->type() != data_type) {
141       RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The output type doesn't match the input tensor type.");
142     }
143     if (input.at(0)->shape() != TensorShape({})) {
144       RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input tensors must be a scalar tensor.");
145     }
146     return input.at(0)->GetItemAt(o, {});
147   }
148 
149   /// Convert a TensorRow consisting of one 1-D tensor to a vector of size n.
150   /// \tparam `T`
151   /// \param[in] o TensorRow consisting of one 1-D tensor
152   /// \param[out] o vector of primitive variable
153   template <typename T>
ConvertFromTensorRow(const TensorRow & input,std::vector<T> * o)154   static Status ConvertFromTensorRow(const TensorRow &input, std::vector<T> *o) {
155     RETURN_UNEXPECTED_IF_NULL(o);
156     DataType data_type = DataType::FromCType<T>();
157     RETURN_IF_NOT_OK(ValidateTensorRow(input, data_type));
158     if (input.at(0)->Rank() != 1) {
159       RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input tensor must have a rank of 1.");
160     }
161     for (auto it = input.at(0)->begin<T>(); it != input.at(0)->end<T>(); it++) {
162       o->push_back(*it);
163     }
164     return Status::OK();
165   }
166 
167   // Functions to fetch/set id/vector
getId()168   row_id_type getId() const { return id_; }
169 
setId(row_id_type id)170   void setId(row_id_type id) { id_ = id; }
171 
getPath()172   std::vector<std::string> getPath() const { return path_; }
173 
setPath(const std::vector<std::string> & path)174   void setPath(const std::vector<std::string> &path) { path_ = path; }
175 
getRow()176   const vector_type &getRow() const { return row_; }
177 
SizeInBytes()178   dsize_t SizeInBytes() const {
179     dsize_t sz = 0;
180     for (auto &it : row_) {
181       sz += it->SizeInBytes();
182     }
183     return sz;
184   }
185 
186   // Wrapper functions to support vector operations
emplace_back(value_type t)187   void emplace_back(value_type t) { row_.emplace_back(t); }
188 
push_back(value_type t)189   void push_back(value_type t) { row_.push_back(t); }
190 
clear()191   void clear() noexcept { row_.clear(); }
192 
193   // Reset both the tensor row vector and flags
reset()194   void reset() noexcept {
195     row_.clear();
196     tensor_row_flag_ = kFlagNone;
197   }
198 
size()199   size_type size() const noexcept { return row_.size(); }
200 
reserve(size_type size)201   void reserve(size_type size) { row_.reserve(size); }
202 
resize(size_type size)203   void resize(size_type size) { row_.resize(size); }
204 
empty()205   const bool empty() { return row_.empty(); }
206 
insert(const_iterator position,const_iterator first,const_iterator last)207   void insert(const_iterator position, const_iterator first, const_iterator last) {
208     row_.insert(position, first, last);
209   }
210 
211   // Wrapper functions to support vector element access
at(size_type index)212   reference at(size_type index) { return row_.at(index); }
213 
at(size_type index)214   const_reference at(size_type index) const { return row_.at(index); }
215 
front()216   reference front() { return row_.front(); }
217 
front()218   const_reference front() const { return row_.front(); }
219 
back()220   reference back() { return row_.back(); }
221 
back()222   const_reference back() const { return row_.back(); }
223 
224   reference operator[](size_type index) { return row_[index]; }
225 
226   const_reference operator[](size_type index) const { return row_[index]; }
227 
228   // Wrapper functions to support vector iteration
begin()229   iterator begin() { return row_.begin(); }
230 
begin()231   const_iterator begin() const { return row_.begin(); }
232 
end()233   iterator end() { return row_.end(); }
234 
end()235   const_iterator end() const { return row_.end(); }
236 
237   // Convenience getter functions for flag checking
eof()238   bool eof() const { return (static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagEOF)); }
239 
eoe()240   bool eoe() const { return (static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagEOE)); }
241 
wait()242   bool wait() const { return (static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagWait)); }
243 
quit()244   bool quit() const { return (static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagQuit)); }
245 
skip()246   bool skip() const {
247     return static_cast<bool>(static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagSkip));
248   }
249 
error()250   bool error() const {
251     return static_cast<bool>(static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagError));
252   }
253 
Flags()254   TensorRowFlags Flags() const { return tensor_row_flag_; }
255 
FlagName()256   std::string FlagName() const {
257     switch (tensor_row_flag_) {
258       case TensorRowFlags::kFlagNone:
259         return "Data";
260       case TensorRowFlags::kFlagEOE:
261         return "EOE";
262       case TensorRowFlags::kFlagEOF:
263         return "EOF";
264       case TensorRowFlags::kFlagWait:
265         return "Wait";
266       case TensorRowFlags::kFlagQuit:
267         return "Quit";
268       case TensorRowFlags::kFlagSkip:
269         return "Skip";
270       case TensorRowFlags::kFlagError:
271         return "Error";
272       default:
273         return "Unknown";
274     }
275   }
276 
277   explicit TensorRow(TensorRowFlags flag);
278 
279  protected:
280   row_id_type id_;
281   std::vector<std::string> path_;
282   std::vector<std::shared_ptr<Tensor>> row_;
283 
284   TensorRowFlags tensor_row_flag_;
285 
286  private:
287   /// Validate data type of TensorRow for conversions.
288   /// \param[in] input TensorRow
289   /// \param[in] data_type data type of the tensor row
290   static Status ValidateTensorRow(const TensorRow &input, const DataType &data_type);
291 };
292 }  // namespace dataset
293 }  // namespace mindspore
294 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_ROW_H_
295