1 /** 2 * Copyright 2020-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_ITERATOR_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_ITERATOR_H_ 19 20 #include <map> 21 #include <memory> 22 #include <string> 23 #include <unordered_map> 24 #include <vector> 25 26 #include "include/api/dual_abi_helper.h" 27 #include "include/api/status.h" 28 #include "include/api/types.h" 29 30 namespace mindspore { 31 namespace dataset { 32 // Forward declare 33 class ExecutionTree; 34 class DatasetOp; 35 class Tensor; 36 37 class NativeRuntimeContext; 38 class IteratorConsumer; 39 class PullBasedIteratorConsumer; 40 41 class Dataset; 42 43 using MSTensorMap = std::unordered_map<std::string, mindspore::MSTensor>; 44 using MSTensorMapChar = std::map<std::vector<char>, mindspore::MSTensor>; 45 using MSTensorVec = std::vector<mindspore::MSTensor>; 46 47 // Abstract class for iterating over the dataset. 48 class DATASET_API Iterator { 49 public: 50 /// \brief Constructor. 51 Iterator(); 52 53 /// \brief Destructor. 54 virtual ~Iterator(); 55 56 /// \brief Method for building and launching the pipeline. 57 /// \param[in] ds The last DatasetOp in the dataset pipeline. 58 /// \param[in] num_epochs Number of epochs passed down to EpochCtrlNode (default=-1, which means infinite epochs). 59 /// \return Status error code, returns OK if no error encountered. 60 virtual Status BuildAndLaunchTree(const std::shared_ptr<Dataset> &ds, int32_t num_epochs); 61 62 /// \brief Function to get the next row from the data pipeline. 63 /// \note Type of return data is a unordered_map(with column name). 64 /// \param[out] row The output tensor row. 65 /// \return Status error code, returns OK if no error encountered. 66 /// \par Example 67 /// \code 68 /// /* dataset is an instance of Dataset object */ 69 /// std::shared_ptr<Iterator> = dataset->CreateIterator(); 70 /// std::unordered_map<std::string, mindspore::MSTensor> row; 71 /// iter->GetNextRow(&row); 72 /// \endcode GetNextRow(MSTensorMap * row)73 Status GetNextRow(MSTensorMap *row) { 74 if (row == nullptr) { 75 return Status(kMDUnexpectedError, "Got nullptr when GetNext row."); 76 } 77 MSTensorMapChar row_; 78 row_.clear(); 79 row->clear(); 80 Status s = GetNextRowCharIF(&row_); 81 TensorMapCharToString(&row_, row); 82 return s; 83 } 84 85 /// \brief Char interface(CharIF) of GetNextRow. 86 /// \note The reason for using this API is that std::string will be constrained by the 87 /// compiler option '_GLIBCXX_USE_CXX11_ABI' while char is free of this restriction. 88 Status GetNextRowCharIF(MSTensorMapChar *row); 89 90 /// \brief Function to get the next row from the data pipeline. 91 /// \note Type of return data is a vector(without column name). 92 /// \param[out] row The output tensor row. 93 /// \return Status error code, returns OK if no error encountered. 94 /// \par Example 95 /// \code 96 /// /* dataset is an instance of Dataset object */ 97 /// std::shared_ptr<Iterator> = dataset->CreateIterator(); 98 /// std::vector<mindspore::MSTensor> row; 99 /// iter->GetNextRow(&row); 100 /// \endcode 101 virtual Status GetNextRow(MSTensorVec *row); 102 103 /// \brief Function to shut down the data pipeline. 104 void Stop(); 105 106 /// \brief Inter class as iterator of Iterator. 107 class _Iterator { 108 public: 109 /// \brief Constructor 110 explicit _Iterator(Iterator *lt); 111 112 /// \brief Destructor ~_Iterator()113 ~_Iterator() { 114 if (cur_row_ != nullptr) { 115 delete cur_row_; 116 cur_row_ = nullptr; 117 } 118 } 119 120 /// \brief prefix ++ overload 121 _Iterator &operator++(); 122 123 /// \brief dereference operator 124 MSTensorMap &operator*() { return *cur_row_; } 125 126 /// \brief dereference operator 127 MSTensorMap *operator->() { return cur_row_; } 128 129 /// \brief bool operator 130 bool operator!=(const _Iterator &rhs) { return cur_row_ != rhs.cur_row_; } 131 132 private: 133 int ind_; // the cur node our Iterator points to 134 Iterator *lt_; 135 MSTensorMap *cur_row_; 136 }; 137 138 /// \brief Function to return the iterator points to the begin of Iterator. begin()139 _Iterator begin() { return _Iterator(this); } 140 141 /// \brief Function to return the iterator points to the end of Iterator. end()142 _Iterator end() { return _Iterator(nullptr); } 143 144 private: 145 std::unique_ptr<NativeRuntimeContext> runtime_context_; 146 IteratorConsumer *consumer_; 147 }; 148 149 class DATASET_API PullIterator : public Iterator { 150 public: 151 /// \brief Constructor. 152 PullIterator(); 153 154 /// \brief Destructor. 155 ~PullIterator() override; 156 157 /// \brief Function to get next row from the data pipeline. 158 /// \note Type of return data is a vector(without column name). 159 /// \param[out] row The output tensor row. 160 /// \return Status error code, returns OK if no error encountered else false. 161 /// \par Example 162 /// \code 163 /// /* dataset is an instance of Dataset object */ 164 /// std::shared_ptr<Iterator> = dataset->CreatePullBasedIterator(); 165 /// std::vector<mindspore::MSTensor> row; 166 /// iter->GetNextRow(&row); 167 /// \endcode 168 Status GetNextRow(MSTensorVec *const row) override; 169 170 /// \brief Function to get specified rows from the data pipeline. 171 /// \note Type of return data is a vector(without column name). This behavior is subject to change. 172 /// \param[in] num_rows The number of rows to fetch. 173 /// \param[out] row The output tensor row. 174 /// \return Status error code, returns OK if no error encountered else false. 175 /// \par Example 176 /// \code 177 /// /* dataset is an instance of Dataset object */ 178 /// std::shared_ptr<Iterator> = dataset->CreatePullBasedIterator(); 179 /// std::vector<std::vector<mindspore::MSTensor>> rows; 180 /// iter->GetNextRow(5, &rows); 181 /// \endcode 182 Status GetRows(int32_t num_rows, std::vector<MSTensorVec> *const row); 183 184 /// \brief Method for building and launching the pipeline. 185 /// \note Consider making this function protected. 186 /// \param[in] ds The root node that calls the function. 187 /// \param[in] num_epochs Number of epochs passed down to EpochCtrlNode (default=-1, which means infinite epochs). 188 /// \return Status error code, returns OK if no error encountered. 189 Status BuildAndLaunchTree(const std::shared_ptr<Dataset> &ds, int32_t num_epochs) override; 190 191 private: 192 std::unique_ptr<PullBasedIteratorConsumer> pull_consumer_; 193 }; 194 } // namespace dataset 195 } // namespace mindspore 196 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_ITERATOR_H_ 197