• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_ITERATOR_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_ITERATOR_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <vector>
25 
26 #include "include/api/dual_abi_helper.h"
27 #include "include/api/status.h"
28 #include "include/api/types.h"
29 
30 namespace mindspore {
31 namespace dataset {
32 // Forward declare
33 class ExecutionTree;
34 class DatasetOp;
35 class Tensor;
36 
37 class NativeRuntimeContext;
38 class IteratorConsumer;
39 class PullBasedIteratorConsumer;
40 
41 class Dataset;
42 
43 using MSTensorMap = std::unordered_map<std::string, mindspore::MSTensor>;
44 using MSTensorMapChar = std::map<std::vector<char>, mindspore::MSTensor>;
45 using MSTensorVec = std::vector<mindspore::MSTensor>;
46 
47 // Abstract class for iterating over the dataset.
48 class DATASET_API Iterator {
49  public:
50   /// \brief Constructor.
51   Iterator();
52 
53   /// \brief Destructor.
54   virtual ~Iterator();
55 
56   /// \brief Method for building and launching the pipeline.
57   /// \param[in] ds The last DatasetOp in the dataset pipeline.
58   /// \param[in] num_epochs Number of epochs passed down to EpochCtrlNode (default=-1, which means infinite epochs).
59   /// \return Status error code, returns OK if no error encountered.
60   virtual Status BuildAndLaunchTree(const std::shared_ptr<Dataset> &ds, int32_t num_epochs);
61 
62   /// \brief Function to get the next row from the data pipeline.
63   /// \note Type of return data is a unordered_map(with column name).
64   /// \param[out] row The output tensor row.
65   /// \return Status error code, returns OK if no error encountered.
66   /// \par Example
67   /// \code
68   ///      /* dataset is an instance of Dataset object */
69   ///      std::shared_ptr<Iterator> = dataset->CreateIterator();
70   ///      std::unordered_map<std::string, mindspore::MSTensor> row;
71   ///      iter->GetNextRow(&row);
72   /// \endcode
GetNextRow(MSTensorMap * row)73   Status GetNextRow(MSTensorMap *row) {
74     if (row == nullptr) {
75       return Status(kMDUnexpectedError, "Got nullptr when GetNext row.");
76     }
77     MSTensorMapChar row_;
78     row_.clear();
79     row->clear();
80     Status s = GetNextRowCharIF(&row_);
81     TensorMapCharToString(&row_, row);
82     return s;
83   }
84 
85   /// \brief Char interface(CharIF) of GetNextRow.
86   /// \note The reason for using this API is that std::string will be constrained by the
87   ///    compiler option '_GLIBCXX_USE_CXX11_ABI' while char is free of this restriction.
88   Status GetNextRowCharIF(MSTensorMapChar *row);
89 
90   /// \brief Function to get the next row from the data pipeline.
91   /// \note Type of return data is a vector(without column name).
92   /// \param[out] row The output tensor row.
93   /// \return Status error code, returns OK if no error encountered.
94   /// \par Example
95   /// \code
96   ///      /* dataset is an instance of Dataset object */
97   ///      std::shared_ptr<Iterator> = dataset->CreateIterator();
98   ///      std::vector<mindspore::MSTensor> row;
99   ///      iter->GetNextRow(&row);
100   /// \endcode
101   virtual Status GetNextRow(MSTensorVec *row);
102 
103   /// \brief Function to shut down the data pipeline.
104   void Stop();
105 
106   /// \brief Inter class as iterator of Iterator.
107   class _Iterator {
108    public:
109     /// \brief Constructor
110     explicit _Iterator(Iterator *lt);
111 
112     /// \brief Destructor
~_Iterator()113     ~_Iterator() {
114       if (cur_row_ != nullptr) {
115         delete cur_row_;
116         cur_row_ = nullptr;
117       }
118     }
119 
120     /// \brief prefix ++ overload
121     _Iterator &operator++();
122 
123     /// \brief dereference operator
124     MSTensorMap &operator*() { return *cur_row_; }
125 
126     /// \brief dereference operator
127     MSTensorMap *operator->() { return cur_row_; }
128 
129     /// \brief bool operator
130     bool operator!=(const _Iterator &rhs) { return cur_row_ != rhs.cur_row_; }
131 
132    private:
133     int ind_;  // the cur node our Iterator points to
134     Iterator *lt_;
135     MSTensorMap *cur_row_;
136   };
137 
138   /// \brief Function to return the iterator points to the begin of Iterator.
begin()139   _Iterator begin() { return _Iterator(this); }
140 
141   /// \brief Function to return the iterator points to the end of Iterator.
end()142   _Iterator end() { return _Iterator(nullptr); }
143 
144  private:
145   std::unique_ptr<NativeRuntimeContext> runtime_context_;
146   IteratorConsumer *consumer_;
147 };
148 
/// \brief Iterator subclass that holds a PullBasedIteratorConsumer and overrides the vector-based
///    GetNextRow() and BuildAndLaunchTree().
/// \note Because GetNextRow(MSTensorVec *) is redeclared here, the base-class
///    GetNextRow(MSTensorMap *) overload is hidden when calling through a PullIterator-typed
///    object or pointer; it remains reachable through an Iterator pointer or reference.
class DATASET_API PullIterator : public Iterator {
 public:
  /// \brief Constructor.
  PullIterator();

  /// \brief Destructor.
  ~PullIterator() override;

  /// \brief Function to get next row from the data pipeline.
  /// \note Type of return data is a vector (without column name).
  /// \param[out] row The output tensor row.
  /// \return Status error code, returns OK if no error encountered.
  /// \par Example
  /// \code
  ///      /* dataset is an instance of Dataset object */
  ///      std::shared_ptr<Iterator> iter = dataset->CreatePullBasedIterator();
  ///      std::vector<mindspore::MSTensor> row;
  ///      iter->GetNextRow(&row);
  /// \endcode
  Status GetNextRow(MSTensorVec *const row) override;

  /// \brief Function to get specified rows from the data pipeline.
  /// \note Type of return data is a vector of rows (each without column names). This behavior is
  ///    subject to change.
  /// \param[in] num_rows The number of rows to fetch.
  /// \param[out] row The output rows, one MSTensorVec per fetched row.
  /// \return Status error code, returns OK if no error encountered.
  /// \par Example
  /// \code
  ///      /* dataset is an instance of Dataset object */
  ///      /* GetRows is declared on PullIterator, so iter must be a PullIterator handle */
  ///      /* (assumes CreatePullBasedIterator() returns one - TODO confirm) */
  ///      auto iter = dataset->CreatePullBasedIterator();
  ///      std::vector<std::vector<mindspore::MSTensor>> rows;
  ///      iter->GetRows(5, &rows);
  /// \endcode
  Status GetRows(int32_t num_rows, std::vector<MSTensorVec> *const row);

  /// \brief Method for building and launching the pipeline.
  /// \note Consider making this function protected.
  /// \param[in] ds The root node that calls the function.
  /// \param[in] num_epochs Number of epochs passed down to EpochCtrlNode (-1 means infinite epochs).
  ///    This declaration has no default argument.
  /// \return Status error code, returns OK if no error encountered.
  Status BuildAndLaunchTree(const std::shared_ptr<Dataset> &ds, int32_t num_epochs) override;

 private:
  // Consumer owned exclusively by this iterator through std::unique_ptr.
  std::unique_ptr<PullBasedIteratorConsumer> pull_consumer_;
};
194 }  // namespace dataset
195 }  // namespace mindspore
196 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_ITERATOR_H_
197