• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_ROW_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_ROW_H_
19 
#include <cstdint>
#include <deque>
#include <initializer_list>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/core/tensor.h"
26 
27 namespace mindspore {
28 namespace dataset {
29 
class TensorRow;  // Forward declaration: a set of Tensor pointers with an id

// The table of tensors is a vector of rows
using TensorTable = std::vector<TensorRow>;
// A different flavour of tensor table; being a deque, this one has queue functionality
using TensorQTable = std::deque<TensorRow>;
33 
34 class TensorRow {
35  public:
36   static constexpr row_id_type kDefaultRowId = -1;  // Default row id
37 
38   enum TensorRowFlags : uint32_t {
39     kFlagNone = 0,
40     kFlagEOF = 1,         // The row is an eof end-of-data msg
41     kFlagEOE = 1u << 1,   // The row is an eoe end-of-epoch msg
42     kFlagWait = 1u << 2,  // The row is an control signal for workers to suspend operations
43     kFlagQuit = 1u << 3   // The row is a control signal for workers to quit
44   };
45 
46   // Type definitions
47   using size_type = size_t;
48   using value_type = std::shared_ptr<Tensor>;
49   using reference = std::shared_ptr<Tensor> &;
50   using const_reference = const std::shared_ptr<Tensor> &;
51   using vector_type = std::vector<std::shared_ptr<Tensor>>;
52   using iterator = std::vector<std::shared_ptr<Tensor>>::iterator;
53   using const_iterator = std::vector<std::shared_ptr<Tensor>>::const_iterator;
54 
55   TensorRow() noexcept;
56 
57   TensorRow(size_type n, const value_type &t) noexcept;
58 
59   // Copy Constructors
60   explicit TensorRow(const vector_type &v);
61 
62   TensorRow(row_id_type id, const std::initializer_list<value_type> &lst);
63 
64   TensorRow(const TensorRow &tr);
65 
66   TensorRow &operator=(const TensorRow &tr);
67 
68   TensorRow &operator=(const std::initializer_list<value_type> &lst);
69 
70   // Move Constructors
71   explicit TensorRow(vector_type &&v) noexcept;
72 
73   TensorRow(row_id_type id, std::initializer_list<value_type> &&lst) noexcept;
74 
75   TensorRow(TensorRow &&tr) noexcept;
76 
77   TensorRow &operator=(TensorRow &&tr) noexcept;
78 
79   TensorRow &operator=(std::initializer_list<value_type> &&lst) noexcept;
80 
81   // Destructor
82   ~TensorRow() = default;
83 
84   /// Convert a vector of primitive types to a TensorRow consisting of one 1-D Tensor with the shape n.
85   /// \tparam `T`
86   /// \param[in] o input vector
87   /// \param[out] output TensorRow
88   template <typename T>
ConvertToTensorRow(const std::vector<T> & o,TensorRow * output)89   static Status ConvertToTensorRow(const std::vector<T> &o, TensorRow *output) {
90     RETURN_UNEXPECTED_IF_NULL(output);
91     DataType data_type = DataType::FromCType<T>();
92     if (data_type == DataType::DE_UNKNOWN) {
93       RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type was not recognized.");
94     }
95     if (data_type == DataType::DE_STRING) {
96       RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported.");
97     }
98     std::shared_ptr<Tensor> tensor;
99     RETURN_IF_NOT_OK(Tensor::CreateFromVector(o, &tensor));
100     output->push_back(tensor);
101     return Status::OK();
102   }
103 
104   /// Convert a single primitive type to a TensorRow consisting of one single data Tensor.
105   /// \tparam `T`
106   /// \param[in] o input
107   /// \param[out] output TensorRow
108   template <typename T>
ConvertToTensorRow(const T & o,TensorRow * output)109   static Status ConvertToTensorRow(const T &o, TensorRow *output) {
110     RETURN_UNEXPECTED_IF_NULL(output);
111     DataType data_type = DataType::FromCType<T>();
112     if (data_type == DataType::DE_UNKNOWN) {
113       RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type was not recognized.");
114     }
115     if (data_type == DataType::DE_STRING) {
116       RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported.");
117     }
118     std::shared_ptr<Tensor> tensor;
119     RETURN_IF_NOT_OK(Tensor::CreateScalar(o, &tensor));
120     output->push_back(tensor);
121     return Status::OK();
122   }
123 
124   /// Return the value in a TensorRow consisting of 1 single data Tensor.
125   /// \tparam `T`
126   /// \param[in] input TensorRow
127   /// \param[out] o the primitive variable
128   template <typename T>
ConvertFromTensorRow(const TensorRow & input,T * o)129   static Status ConvertFromTensorRow(const TensorRow &input, T *o) {
130     RETURN_UNEXPECTED_IF_NULL(o);
131     DataType data_type = DataType::FromCType<T>();
132     RETURN_IF_NOT_OK(ValidateTensorRow(input, data_type));
133     if (input.at(0)->type() != data_type) {
134       RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The output type doesn't match the input tensor type.");
135     }
136     if (input.at(0)->shape() != TensorShape({})) {
137       RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input tensors must be a scalar tensor.");
138     }
139     return input.at(0)->GetItemAt(o, {});
140   }
141 
142   /// Convert a TensorRow consisting of one 1-D tensor to a vector of size n.
143   /// \tparam `T`
144   /// \param[in] o TensorRow consisting of one 1-D tensor
145   /// \param[out] o vector of primitive variable
146   template <typename T>
ConvertFromTensorRow(const TensorRow & input,std::vector<T> * o)147   static Status ConvertFromTensorRow(const TensorRow &input, std::vector<T> *o) {
148     RETURN_UNEXPECTED_IF_NULL(o);
149     DataType data_type = DataType::FromCType<T>();
150     RETURN_IF_NOT_OK(ValidateTensorRow(input, data_type));
151     if (input.at(0)->Rank() != 1)
152       RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input tensor must have a rank of 1.");
153     for (auto it = input.at(0)->begin<T>(); it != input.at(0)->end<T>(); it++) {
154       o->push_back(*it);
155     }
156     return Status::OK();
157   }
158 
159   // Functions to fetch/set id/vector
getId()160   row_id_type getId() const { return id_; }
161 
setId(row_id_type id)162   void setId(row_id_type id) { id_ = id; }
163 
getPath()164   std::vector<std::string> getPath() const { return path_; }
165 
setPath(std::vector<std::string> path)166   void setPath(std::vector<std::string> path) { path_ = path; }
167 
getRow()168   const vector_type &getRow() const { return row_; }
169 
SizeInBytes()170   int64_t SizeInBytes() const {
171     size_t sz = 0;
172     for (auto &it : row_) {
173       sz += it->SizeInBytes();
174     }
175     return sz;
176   }
177 
178   // Wrapper functions to support vector operations
emplace_back(value_type t)179   void emplace_back(value_type t) { row_.emplace_back(t); }
180 
push_back(value_type t)181   void push_back(value_type t) { row_.push_back(t); }
182 
clear()183   void clear() noexcept { row_.clear(); }
184 
size()185   size_type size() const noexcept { return row_.size(); }
186 
reserve(size_type size)187   void reserve(size_type size) { row_.reserve(size); }
188 
resize(size_type size)189   void resize(size_type size) { row_.resize(size); }
190 
empty()191   bool empty() { return row_.empty(); }
192 
insert(iterator position,iterator first,iterator last)193   void insert(iterator position, iterator first, iterator last) { row_.insert(position, first, last); }
194 
195   // Wrapper functions to support vector element access
at(size_type index)196   reference at(size_type index) { return row_.at(index); }
197 
at(size_type index)198   const_reference at(size_type index) const { return row_.at(index); }
199 
front()200   reference front() { return row_.front(); }
201 
front()202   const_reference front() const { return row_.front(); }
203 
back()204   reference back() { return row_.back(); }
205 
back()206   const_reference back() const { return row_.back(); }
207 
208   reference operator[](size_type index) { return row_[index]; }
209 
210   const_reference operator[](size_type index) const { return row_[index]; }
211 
212   // Wrapper functions to support vector iteration
begin()213   iterator begin() { return row_.begin(); }
214 
begin()215   const_iterator begin() const { return row_.begin(); }
216 
end()217   iterator end() { return row_.end(); }
218 
end()219   const_iterator end() const { return row_.end(); }
220 
221   // Convenience getter functions for flag checking
eof()222   bool eof() const { return (static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagEOF)); }
223 
eoe()224   bool eoe() const { return (static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagEOE)); }
225 
wait()226   bool wait() const { return (static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagWait)); }
227 
quit()228   bool quit() const { return (static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagQuit)); }
229 
Flags()230   TensorRowFlags Flags() { return tensor_row_flag_; }
231 
232   explicit TensorRow(TensorRowFlags);
233 
234  protected:
235   row_id_type id_;
236   std::vector<std::string> path_;
237   std::vector<std::shared_ptr<Tensor>> row_;
238 
239   TensorRowFlags tensor_row_flag_;
240 
241  private:
242   /// Validate data type of TensorRow for conversions.
243   /// \param[in] input TensorRow
244   /// \param[in] data_type data type of the tensor row
245   static Status ValidateTensorRow(const TensorRow &input, const DataType &data_type);
246 };
247 }  // namespace dataset
248 }  // namespace mindspore
249 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_ROW_H_
250