1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_ 18 19 #include <memory> 20 #include <string> 21 #include <unordered_map> 22 #include <utility> 23 #include <vector> 24 #include "minddata/dataset/engine/consumers/tree_consumer.h" 25 26 namespace mindspore::dataset { 27 28 /// Consumer that iterates over the dataset and returns the rows one by one as a python list or a dict 29 30 class PythonIteratorConsumer : public IteratorConsumer { 31 public: 32 /// Constructor which will call the base class default constructor. 33 /// \param num_epochs number of epochs. Default to -1 (infinite epochs). IteratorConsumer(num_epochs)34 explicit PythonIteratorConsumer(int32_t num_epochs = -1) : IteratorConsumer(num_epochs) {} 35 36 ~PythonIteratorConsumer() = default; 37 /// Returns the next row in a vector format 38 /// \param[out] out std::vector of Tensors 39 /// \return Status error code 40 Status GetNextAsList(py::list *out); 41 42 /// Returns the next row in as a map 43 /// \param[out] out std::map of string to Tensor 44 /// \return Status error code 45 Status GetNextAsDict(py::dict *out); 46 }; 47 48 class PythonBuildVocabConsumer : public BuildVocabConsumer { 49 public: 50 Status Start() override; 51 }; 52 53 class PythonSaveToDisk : public SaveToDisk { 54 public: 55 PythonSaveToDisk(const std::string &datasetPath, int32_t numFiles, const std::string &datasetType); 56 ~PythonSaveToDisk() = default; 57 Status Save() override; 58 }; 59 60 class PythonTreeGetters : public TreeGetters { 61 public: 62 Status GetRow(TensorRow *const r) override; 63 ~PythonTreeGetters() = default; 64 }; 65 class PythonDatasetSizeGetter : public DatasetSizeGetter { 66 public: 67 Status GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *r) override; 68 ~PythonDatasetSizeGetter() = default; 69 }; 70 } // namespace mindspore::dataset 71 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_ 72