1 /** 2 * Copyright 2021-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PULL_BASED_TREE_CONSUMER_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PULL_BASED_TREE_CONSUMER_H_ 18 19 #include <cstddef> 20 #include <memory> 21 #include <string> 22 #include <unordered_map> 23 #include <utility> 24 #include <vector> 25 26 #include "minddata/dataset/engine/consumers/tree_consumer.h" 27 #include "minddata/dataset/engine/tree_adapter_lite.h" 28 29 namespace mindspore::dataset { 30 class TreeAdapterLite; 31 class TensorRow; 32 33 /// Consumer that iterates over the dataset and returns the rows one by one as a in a pull based fashion 34 class PullBasedIteratorConsumer : public TreeConsumer { 35 public: 36 /// Constructor 37 /// \param num_epochs number of epochs. Default: 1. 38 explicit PullBasedIteratorConsumer(int32_t num_epochs = 1) TreeConsumer(num_epochs)39 : TreeConsumer(num_epochs), tree_adapter_lite_(std::make_unique<TreeAdapterLite>()) {} 40 41 ~PullBasedIteratorConsumer() override = default; 42 43 Status Init(const std::shared_ptr<DatasetNode> &root) override; 44 45 /// \brief Returns the next row in a vector format 46 /// \note This is currently a placeholder function 47 /// \param[in] num_rows the number of rows that we want to get 48 /// \return out std::vector of TensorRows 49 std::vector<TensorRow> GetRows(int64_t num_rows); 50 51 /// Returns the next row in a vector format 52 /// \param[out] out std::vector of Tensors 53 /// \return Status error code 54 Status GetNextAsVector(std::vector<TensorPtr> *const out) override; 55 56 /// Returns the next row in as a map 57 /// \param[out] out std::map of string to Tensor 58 /// \return Status error code 59 Status GetNextAsMap(std::unordered_map<std::string, TensorPtr> *const out) override; 60 61 /// Returns the next row in as a vector 62 /// \param[out] vec std::vector of pairs of string to Tensor 63 /// \return Status error code 64 Status GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *const vec) override; 65 66 /// Function to reset the current consumer to the provided step. 67 /// \note Reset is NOT supported for pull-based iterators. 68 /// \param step the step to reset the pipeline to. 69 /// \param dataset_size the number of steps that one epoch has. 70 /// \return Status error code Reset(int64_t step,int64_t dataset_size)71 Status Reset(int64_t step, int64_t dataset_size) override { 72 RETURN_STATUS_UNEXPECTED( 73 "Failover reset is not supported for pull-based iterators (including when Debug mode is enabled)."); 74 } 75 76 protected: 77 /// Method to return the name of the consumer 78 /// \return string Name()79 std::string Name() override { return "PullBasedIteratorConsumer"; } 80 std::unique_ptr<TreeAdapterLite> tree_adapter_lite_; 81 82 private: 83 std::vector<std::pair<std::string, int32_t>> column_order_; // key: column name, val: column id 84 }; 85 86 /// Consumer that is used to get some pipeline information 87 class TreeGetters : public PullBasedIteratorConsumer { 88 public: 89 TreeGetters(); 90 91 ~TreeGetters() override = default; 92 93 Status Init(const std::shared_ptr<DatasetNode> &root) override; 94 95 Status GetOutputTypes(std::vector<DataType> *types); 96 97 Status GetOutputShapes(std::vector<TensorShape> *shapes, bool estimate = false); 98 99 Status GetBatchSize(int64_t *batch_size); 100 101 Status GetRepeatCount(int64_t *repeat_count); 102 103 Status GetNumClasses(int64_t *num_classes); 104 105 Status GetColumnNames(std::vector<std::string> *output); 106 107 Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing); 108 Name()109 std::string Name() override { return "TreeGetters"; } 110 111 virtual Status GetRow(TensorRow *row); 112 113 private: 114 Status GetFirstRowShapeAndType(); 115 116 std::shared_ptr<DatasetNode> root_; 117 std::vector<DataType> first_row_type_; 118 std::vector<TensorShape> first_row_shape_; 119 std::vector<TensorShape> estimated_row_shape_; 120 bool first_row_obtained_; // whether first row (which could be empty) is obtained by TreeGetter 121 bool init_flag_; // indicate whether the tree has initialized 122 123 Status InternalInit(); 124 }; 125 } // namespace mindspore::dataset 126 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PULL_BASED_TREE_CONSUMER_H_ 127