1 /** 2 * Copyright 2020-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_ 18 19 #include <memory> 20 #include <string> 21 #include <unordered_map> 22 23 #include "pybind11/pybind11.h" 24 25 #include "minddata/dataset/engine/consumers/pull_based_tree_consumer.h" 26 #include "minddata/dataset/engine/consumers/tree_consumer.h" 27 28 namespace mindspore::dataset { 29 30 /// Consumer that iterates over the dataset and returns the rows one by one as a python list or a dict 31 32 class PythonIteratorConsumer : public IteratorConsumer { 33 public: 34 /// Constructor which will call the base class default constructor. 35 /// \param num_epochs number of epochs. Default to -1 (infinite epochs). IteratorConsumer(num_epochs)36 explicit PythonIteratorConsumer(int32_t num_epochs = -1) : IteratorConsumer(num_epochs) {} 37 38 ~PythonIteratorConsumer() = default; 39 40 /// Returns the next row in a vector format 41 /// \param[out] out std::vector of Tensors 42 /// \return Status error code 43 Status GetNextAsList(const py::list *out); 44 45 /// Returns the next row in as a map 46 /// \param[out] out std::map of string to Tensor 47 /// \return Status error code 48 Status GetNextAsDict(const py::dict *out); 49 }; 50 51 class PythonPullBasedIteratorConsumer : public PullBasedIteratorConsumer { 52 public: 53 /// Constructor which will call the base class default constructor. 54 /// \param num_epochs number of epochs. Default to -1 (infinite epochs). PullBasedIteratorConsumer(num_epochs)55 explicit PythonPullBasedIteratorConsumer(int32_t num_epochs = -1) : PullBasedIteratorConsumer(num_epochs) {} 56 57 ~PythonPullBasedIteratorConsumer() = default; 58 59 /// Returns the next row in a vector format 60 /// \param[out] out std::vector of Tensors 61 /// \return Status error code 62 Status GetNextAsList(const py::list *out); 63 64 /// Returns the next row in as a map 65 /// \param[out] out std::map of string to Tensor 66 /// \return Status error code 67 Status GetNextAsDict(const py::dict *out); 68 }; 69 70 class PythonBuildVocabConsumer : public BuildVocabConsumer { 71 public: 72 Status Start() override; 73 }; 74 75 class PythonSaveToDisk : public SaveToDisk { 76 public: 77 PythonSaveToDisk(const std::string &datasetPath, int32_t numFiles, const std::string &datasetType); 78 79 ~PythonSaveToDisk() = default; 80 81 Status Save() override; 82 }; 83 84 class PythonTreeGetters : public TreeGetters { 85 public: 86 Status GetRow(TensorRow *const r) override; 87 88 ~PythonTreeGetters() = default; 89 }; 90 91 class PythonDatasetSizeGetter : public DatasetSizeGetter { 92 public: 93 Status GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *r) override; 94 95 ~PythonDatasetSizeGetter() = default; 96 }; 97 } // namespace mindspore::dataset 98 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_ 99