• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_ITERATOR_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_ITERATOR_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <vector>
25 #include "include/api/dual_abi_helper.h"
26 #include "include/api/status.h"
27 #include "include/api/types.h"
28 
29 namespace mindspore {
30 namespace dataset {
31 
// Forward declarations. Only the names are needed here; the full definitions
// live elsewhere in the project.

// Public dataset node handle accepted by the BuildAndLaunchTree() methods.
class Dataset;

// Runtime and consumer types held as members by Iterator / PullIterator.
class NativeRuntimeContext;
class IteratorConsumer;
class PullBasedIteratorConsumer;

// Engine-side types.
class ExecutionTree;
class DatasetOp;
class Tensor;
42 
43 using MSTensorMap = std::unordered_map<std::string, mindspore::MSTensor>;
44 using MSTensorMapChar = std::map<std::vector<char>, mindspore::MSTensor>;
45 using MSTensorVec = std::vector<mindspore::MSTensor>;
46 
47 // Abstract class for iterating over the dataset.
48 class Iterator {
49  public:
50   /// \brief Constructor.
51   Iterator();
52 
53   /// \brief Destructor.
54   ~Iterator();
55 
56   /// \brief Method for building and launching the pipeline.
57   /// \param[in] ds The last DatasetOp in the dataset pipeline.
58   /// \param[in] num_epochs Number of epochs passed down to EpochCtrlNode (default=-1, which means infinite epochs).
59   /// \return Status error code, returns OK if no error encountered.
60   Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs);
61 
62   /// \brief Function to get the next row from the data pipeline.
63   /// \note Type of return data is a unordered_map(with column name).
64   /// \param[out] row The output tensor row.
65   /// \return Status error code, returns OK if no error encountered.
GetNextRow(MSTensorMap * row)66   Status GetNextRow(MSTensorMap *row) {
67     if (row == nullptr) {
68       return Status(kMDUnexpectedError, "Got nullptr when GetNext row.");
69     }
70     MSTensorMapChar row_;
71     row_.clear();
72     row->clear();
73     Status s = GetNextRowCharIF(&row_);
74     TensorMapCharToString(&row_, row);
75     return s;
76   }
77 
78   /// \brief Char interface(CharIF) of GetNextRow.
79   /// \note The reason for using this API is that std::string will be constrained by the
80   ///    compiler option '_GLIBCXX_USE_CXX11_ABI' while char is free of this restriction.
81   Status GetNextRowCharIF(MSTensorMapChar *row);
82 
83   /// \brief Function to get the next row from the data pipeline.
84   /// \note Type of return data is a vector(without column name).
85   /// \param[out] row The output tensor row.
86   /// \return Status error code, returns OK if no error encountered.
87   virtual Status GetNextRow(MSTensorVec *row);
88 
89   /// \brief Function to shut down the data pipeline.
90   void Stop();
91 
92   /// \brief Inter class as iterator of Iterator.
93   class _Iterator {
94    public:
95     /// \brief Constructor
96     explicit _Iterator(Iterator *lt);
97 
98     /// \brief Destructor
~_Iterator()99     ~_Iterator() {
100       if (cur_row_ != nullptr) {
101         delete cur_row_;
102         cur_row_ = nullptr;
103       }
104     }
105 
106     /// \brief prefix ++ overload
107     _Iterator &operator++();
108 
109     /// \brief dereference operator
110     MSTensorMap &operator*() { return *cur_row_; }
111 
112     /// \brief dereference operator
113     MSTensorMap *operator->() { return cur_row_; }
114 
115     /// \brief bool operator
116     bool operator!=(const _Iterator &rhs) { return cur_row_ != rhs.cur_row_; }
117 
118    private:
119     int ind_;  // the cur node our Iterator points to
120     Iterator *lt_;
121     MSTensorMap *cur_row_;
122   };
123 
124   /// \brief Function to return the iterator points to the begin of Iterator.
begin()125   _Iterator begin() { return _Iterator(this); }
126 
127   /// \brief Function to return the iterator points to the end of Iterator.
end()128   _Iterator end() { return _Iterator(nullptr); }
129 
130  private:
131   std::unique_ptr<NativeRuntimeContext> runtime_context_;
132   IteratorConsumer *consumer_;
133 };
134 
135 class PullIterator : public Iterator {
136  public:
137   /// \brief Constructor.
138   PullIterator();
139 
140   /// \brief Destructor.
141   ~PullIterator() = default;
142 
143   /// \brief Function to get next row from the data pipeline.
144   /// \note Type of return data is a vector(without column name).
145   /// \param[out] row The output tensor row.
146   /// \return Status error code, returns OK if no error encountered else false.
147   Status GetNextRow(MSTensorVec *const row) override;
148 
149   /// \brief Function to get specified rows from the data pipeline.
150   /// \note Type of return data is a vector(without column name). This behavior is subject to change.
151   /// \param[in] num_rows The number of rows to fetch.
152   /// \param[out] row The output tensor row.
153   /// \return Status error code, returns OK if no error encountered else false.
154   Status GetRows(int32_t num_rows, std::vector<MSTensorVec> *const row);
155 
156   /// \brief Method for building and launching the pipeline.
157   /// \note Consider making this function protected.
158   /// \param[in] ds The root node that calls the function.
159   /// \return Status error code, returns OK if no error encountered.
160   Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds);
161 
162  private:
163   std::unique_ptr<PullBasedIteratorConsumer> pull_consumer_;
164 };
165 }  // namespace dataset
166 }  // namespace mindspore
167 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_ITERATOR_H_
168