1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_ITERATOR_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_ITERATOR_H_ 19 20 #include <map> 21 #include <memory> 22 #include <string> 23 #include <unordered_map> 24 #include <vector> 25 #include "include/api/dual_abi_helper.h" 26 #include "include/api/status.h" 27 #include "include/api/types.h" 28 29 namespace mindspore { 30 namespace dataset { 31 32 // Forward declare 33 class ExecutionTree; 34 class DatasetOp; 35 class Tensor; 36 37 class NativeRuntimeContext; 38 class IteratorConsumer; 39 class PullBasedIteratorConsumer; 40 41 class Dataset; 42 43 using MSTensorMap = std::unordered_map<std::string, mindspore::MSTensor>; 44 using MSTensorMapChar = std::map<std::vector<char>, mindspore::MSTensor>; 45 using MSTensorVec = std::vector<mindspore::MSTensor>; 46 47 // Abstract class for iterating over the dataset. 48 class Iterator { 49 public: 50 /// \brief Constructor. 51 Iterator(); 52 53 /// \brief Destructor. 54 ~Iterator(); 55 56 /// \brief Method for building and launching the pipeline. 57 /// \param[in] ds The last DatasetOp in the dataset pipeline. 58 /// \param[in] num_epochs Number of epochs passed down to EpochCtrlNode (default=-1, which means infinite epochs). 59 /// \return Status error code, returns OK if no error encountered. 60 Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs); 61 62 /// \brief Function to get the next row from the data pipeline. 63 /// \note Type of return data is a unordered_map(with column name). 64 /// \param[out] row The output tensor row. 65 /// \return Status error code, returns OK if no error encountered. GetNextRow(MSTensorMap * row)66 Status GetNextRow(MSTensorMap *row) { 67 if (row == nullptr) { 68 return Status(kMDUnexpectedError, "Got nullptr when GetNext row."); 69 } 70 MSTensorMapChar row_; 71 row_.clear(); 72 row->clear(); 73 Status s = GetNextRowCharIF(&row_); 74 TensorMapCharToString(&row_, row); 75 return s; 76 } 77 78 /// \brief Char interface(CharIF) of GetNextRow. 79 /// \note The reason for using this API is that std::string will be constrained by the 80 /// compiler option '_GLIBCXX_USE_CXX11_ABI' while char is free of this restriction. 81 Status GetNextRowCharIF(MSTensorMapChar *row); 82 83 /// \brief Function to get the next row from the data pipeline. 84 /// \note Type of return data is a vector(without column name). 85 /// \param[out] row The output tensor row. 86 /// \return Status error code, returns OK if no error encountered. 87 virtual Status GetNextRow(MSTensorVec *row); 88 89 /// \brief Function to shut down the data pipeline. 90 void Stop(); 91 92 /// \brief Inter class as iterator of Iterator. 93 class _Iterator { 94 public: 95 /// \brief Constructor 96 explicit _Iterator(Iterator *lt); 97 98 /// \brief Destructor ~_Iterator()99 ~_Iterator() { 100 if (cur_row_ != nullptr) { 101 delete cur_row_; 102 cur_row_ = nullptr; 103 } 104 } 105 106 /// \brief prefix ++ overload 107 _Iterator &operator++(); 108 109 /// \brief dereference operator 110 MSTensorMap &operator*() { return *cur_row_; } 111 112 /// \brief dereference operator 113 MSTensorMap *operator->() { return cur_row_; } 114 115 /// \brief bool operator 116 bool operator!=(const _Iterator &rhs) { return cur_row_ != rhs.cur_row_; } 117 118 private: 119 int ind_; // the cur node our Iterator points to 120 Iterator *lt_; 121 MSTensorMap *cur_row_; 122 }; 123 124 /// \brief Function to return the iterator points to the begin of Iterator. begin()125 _Iterator begin() { return _Iterator(this); } 126 127 /// \brief Function to return the iterator points to the end of Iterator. end()128 _Iterator end() { return _Iterator(nullptr); } 129 130 private: 131 std::unique_ptr<NativeRuntimeContext> runtime_context_; 132 IteratorConsumer *consumer_; 133 }; 134 135 class PullIterator : public Iterator { 136 public: 137 /// \brief Constructor. 138 PullIterator(); 139 140 /// \brief Destructor. 141 ~PullIterator() = default; 142 143 /// \brief Function to get next row from the data pipeline. 144 /// \note Type of return data is a vector(without column name). 145 /// \param[out] row The output tensor row. 146 /// \return Status error code, returns OK if no error encountered else false. 147 Status GetNextRow(MSTensorVec *const row) override; 148 149 /// \brief Function to get specified rows from the data pipeline. 150 /// \note Type of return data is a vector(without column name). This behavior is subject to change. 151 /// \param[in] num_rows The number of rows to fetch. 152 /// \param[out] row The output tensor row. 153 /// \return Status error code, returns OK if no error encountered else false. 154 Status GetRows(int32_t num_rows, std::vector<MSTensorVec> *const row); 155 156 /// \brief Method for building and launching the pipeline. 157 /// \note Consider making this function protected. 158 /// \param[in] ds The root node that calls the function. 159 /// \return Status error code, returns OK if no error encountered. 160 Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds); 161 162 private: 163 std::unique_ptr<PullBasedIteratorConsumer> pull_consumer_; 164 }; 165 } // namespace dataset 166 } // namespace mindspore 167 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_ITERATOR_H_ 168