• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PULL_BASED_TREE_CONSUMER_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PULL_BASED_TREE_CONSUMER_H_
18 
19 #include <cstddef>
20 #include <memory>
21 #include <string>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 
26 #include "minddata/dataset/engine/consumers/tree_consumer.h"
27 #include "minddata/dataset/engine/tree_adapter_lite.h"
28 
29 namespace mindspore::dataset {
30 class TreeAdapterLite;
31 class TensorRow;
32 
33 /// Consumer that iterates over the dataset and returns the rows one by one as a in a pull based fashion
34 class PullBasedIteratorConsumer : public TreeConsumer {
35  public:
36   /// Constructor
37   /// \param num_epochs number of epochs. Default: 1.
38   explicit PullBasedIteratorConsumer(int32_t num_epochs = 1)
TreeConsumer(num_epochs)39       : TreeConsumer(num_epochs), tree_adapter_lite_(std::make_unique<TreeAdapterLite>()) {}
40 
41   ~PullBasedIteratorConsumer() override = default;
42 
43   Status Init(const std::shared_ptr<DatasetNode> &root) override;
44 
45   /// \brief Returns the next row in a vector format
46   /// \note This is currently a placeholder function
47   /// \param[in] num_rows the number of rows that we want to get
48   /// \return out std::vector of TensorRows
49   std::vector<TensorRow> GetRows(int64_t num_rows);
50 
51   /// Returns the next row in a vector format
52   /// \param[out] out std::vector of Tensors
53   /// \return Status error code
54   Status GetNextAsVector(std::vector<TensorPtr> *const out) override;
55 
56   /// Returns the next row in as a map
57   /// \param[out] out std::map of string to Tensor
58   /// \return Status error code
59   Status GetNextAsMap(std::unordered_map<std::string, TensorPtr> *const out) override;
60 
61   /// Returns the next row in as a vector
62   /// \param[out] vec std::vector of pairs of string to Tensor
63   /// \return Status error code
64   Status GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *const vec) override;
65 
66   /// Function to reset the current consumer to the provided step.
67   /// \note Reset is NOT supported for pull-based iterators.
68   /// \param step the step to reset the pipeline to.
69   /// \param dataset_size the number of steps that one epoch has.
70   /// \return Status error code
Reset(int64_t step,int64_t dataset_size)71   Status Reset(int64_t step, int64_t dataset_size) override {
72     RETURN_STATUS_UNEXPECTED(
73       "Failover reset is not supported for pull-based iterators (including when Debug mode is enabled).");
74   }
75 
76  protected:
77   /// Method to return the name of the consumer
78   /// \return string
Name()79   std::string Name() override { return "PullBasedIteratorConsumer"; }
80   std::unique_ptr<TreeAdapterLite> tree_adapter_lite_;
81 
82  private:
83   std::vector<std::pair<std::string, int32_t>> column_order_;  // key: column name, val: column id
84 };
85 
86 /// Consumer that is used to get some pipeline information
87 class TreeGetters : public PullBasedIteratorConsumer {
88  public:
89   TreeGetters();
90 
91   ~TreeGetters() override = default;
92 
93   Status Init(const std::shared_ptr<DatasetNode> &root) override;
94 
95   Status GetOutputTypes(std::vector<DataType> *types);
96 
97   Status GetOutputShapes(std::vector<TensorShape> *shapes, bool estimate = false);
98 
99   Status GetBatchSize(int64_t *batch_size);
100 
101   Status GetRepeatCount(int64_t *repeat_count);
102 
103   Status GetNumClasses(int64_t *num_classes);
104 
105   Status GetColumnNames(std::vector<std::string> *output);
106 
107   Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
108 
Name()109   std::string Name() override { return "TreeGetters"; }
110 
111   virtual Status GetRow(TensorRow *row);
112 
113  private:
114   Status GetFirstRowShapeAndType();
115 
116   std::shared_ptr<DatasetNode> root_;
117   std::vector<DataType> first_row_type_;
118   std::vector<TensorShape> first_row_shape_;
119   std::vector<TensorShape> estimated_row_shape_;
120   bool first_row_obtained_;  // whether first row (which could be empty) is obtained by TreeGetter
121   bool init_flag_;           // indicate whether the tree has initialized
122 
123   Status InternalInit();
124 };
125 }  // namespace mindspore::dataset
126 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PULL_BASED_TREE_CONSUMER_H_
127