• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_TREE_CONSUMER_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_TREE_CONSUMER_H_
18 
19 #include <memory>
20 #include <string>
21 #include <map>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 
26 #include "minddata/dataset/engine/tree_adapter.h"
27 #include "minddata/dataset/text/vocab.h"
28 
29 namespace mindspore::dataset {
30 // Forward declare
31 class TreeAdapter;
32 class DatasetNode;
33 
34 /// A base class for tree consumers which would fetch rows from the tree pipeline
35 class TreeConsumer {
36  public:
37   /// Constructor that prepares an empty tree_adapter
38   TreeConsumer();
39 
40   /// \brief Destructor
41   ~TreeConsumer() = default;
42   /// Initializes the consumer, this involves constructing and preparing the tree.
43   /// \param d The dataset node that represent the root of the IR tree.
44   /// \return Status error code.
45   virtual Status Init(std::shared_ptr<DatasetNode> d);
46 
47   /// Internal function to perform the termination
48   /// \return Status error code
49   virtual Status Terminate();
50 
51  protected:
52   /// The class owns the tree_adapter that handles execution tree operations.
53   std::unique_ptr<TreeAdapter> tree_adapter_;
54   /// Method to return the name of the consumer
55   /// \return string
56   virtual std::string Name() = 0;
57 };
58 
59 /// Consumer that iterates over the dataset and returns the rows one by one as a vector or a map
60 class IteratorConsumer : public TreeConsumer {
61  public:
62   /// Constructor which will call the base class default constructor.
63   /// \param num_epochs number of epochs. Default to -1 (infinite epochs).
TreeConsumer()64   explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {}
65 
66   ~IteratorConsumer() = default;
67 
68   Status Init(std::shared_ptr<DatasetNode> d) override;
69 
70   /// Returns the next row in a vector format
71   /// \param[out] out std::vector of Tensors
72   /// \return Status error code
73   Status GetNextAsVector(std::vector<TensorPtr> *out);
74 
75   /// Returns the next row in as a map
76   /// \param[out] out std::map of string to Tensor
77   /// \return Status error code
78   Status GetNextAsMap(std::unordered_map<std::string, TensorPtr> *const out);
79 
80   /// Returns the next row in as a vector
81   /// \param[out] out std::vector of pairs of string to Tensor
82   /// \return Status error code
83   Status GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *const vec);
84 
85  protected:
86   /// Method to return the name of the consumer
87   /// \return string
Name()88   std::string Name() override { return "IteratorConsumer"; }
89 
90  private:
91   int32_t num_epochs_;
92   std::map<int32_t, std::string> column_order_;  // key: column id, val: column name
93 };
94 
95 #ifndef ENABLE_ANDROID
96 /// Consumer that iterates over the dataset and writes it to disk
97 class SaveToDisk : public TreeConsumer {
98  public:
99   /// Constructor which will call the base class default constructor.
100   /// \param dataset_path path the the dataset
101   /// \param num_files number of files. Default to 1
102   /// \param dataset_type The format of the dataset. Default to "mindrecod".
103   explicit SaveToDisk(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord")
TreeConsumer()104       : TreeConsumer(), dataset_path_(dataset_path), num_files_(num_files), dataset_type_(dataset_type) {}
105 
106   ~SaveToDisk() = default;
107 
108   /// \brief Parameters validation
109   /// \return Status Status::OK() if all the parameters are valid
110   Status ValidateParams();
111 
112   /// Save the given dataset to MindRecord format on disk. This is a blocking method (i.e., after returning, all rows
113   /// would be written to disk)
114   /// \return  Status error code
115   virtual Status Save();
116 
117  protected:
118   /// Method to return the name of the consumer
119   /// \return string
Name()120   std::string Name() override { return "SaveToDisk"; }
121 
122  private:
123   template <typename T, typename S>
124   Status TransformTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
125                          std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
126                          std::unique_ptr<S> *s, bool need_convert = false);
127 
128   Status FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map,
129                                 const TensorRow &row, nlohmann::json *schema, std::vector<std::string> *index_fields);
130 
131   Status FetchDataFromTensorRow(const TensorRow &row,
132                                 const std::unordered_map<std::string, int32_t> &column_name_id_map,
133                                 nlohmann::json *row_raw_data,
134                                 std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data);
135 
136   Status FetchFloatData(std::shared_ptr<Tensor> tensor, std::string column_name, nlohmann::json *row_raw_data,
137                         std::unique_ptr<std::vector<uint8_t>> *data_ptr);
138 
139   Status FetchItemData(std::shared_ptr<Tensor> tensor, std::string column_name, nlohmann::json *row_raw_data,
140                        std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data);
141 
142   template <typename T>
143   bool map_compare(T const &lhs, T const &rhs);
144 
145   Status CheckTensorRowShapes(const std::unordered_map<std::string, int32_t> &column_name_id_map, const TensorRow &row,
146                               std::map<std::string, std::vector<int>> *PreTensorRowShapes_ptr);
147 
148   std::string dataset_path_;
149   int32_t num_files_;
150   std::string dataset_type_;
151 };
152 #endif
153 
154 /// Consumer that iterates over the dataset and send it to a device
155 class ToDevice : public TreeConsumer {
156  public:
TreeConsumer()157   explicit ToDevice(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {}
158 
159   ~ToDevice() = default;
160 
161   Status Init(std::shared_ptr<DatasetNode> d) override;
162 
163   Status Terminate() override;
164 
165   /// Send the data to device
166   /// \return  Status error code
167   virtual Status Send();
168 
169   /// Stop to send data to device
170   /// \return  Status error code
171   virtual Status Stop();
172 
173   /// Continue to send data to device
174   /// \return  Status error code
175   virtual Status Continue();
176 
177   /// Get data info from TDT
178   /// \return  Status error code
179   virtual Status GetDataInfo(std::vector<DataType> *const types, std::vector<TensorShape> *const shapes);
180 
181  protected:
182   /// Method to return the name of the consumer
183   /// \return string
Name()184   std::string Name() override { return "ToDevice"; }
185 
186  private:
187   int32_t num_epochs_;
188 };
189 
190 /// Consumer that is used to get some pipeline information
191 class TreeGetters : public TreeConsumer {
192  public:
193   TreeGetters();
194   ~TreeGetters() = default;
195   Status Init(std::shared_ptr<DatasetNode> d) override;
196 
197   Status GetOutputTypes(std::vector<DataType> *types);
198   Status GetOutputShapes(std::vector<TensorShape> *shapes);
199   Status GetBatchSize(int64_t *batch_size);
200   Status GetRepeatCount(int64_t *repeat_count);
201   Status GetNumClasses(int64_t *num_classes);
202   Status GetColumnNames(std::vector<std::string> *output);
203   Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
Name()204   std::string Name() override { return "TreeGetters"; }
205   virtual Status GetRow(TensorRow *row);
206 
207  private:
208   Status GetFirstRowShapeAndType();
209 
210   std::shared_ptr<DatasetNode> root_;
211   int64_t dataset_size_;
212   std::vector<DataType> first_row_type_;
213   std::vector<TensorShape> first_row_shape_;
214   bool first_row_obtained_;  // whether first row (which could be empty) is obtained by TreeGetter
215   bool init_flag_;           // indicate whether the tree has initialized
216 
217   Status InternalInit();
218 };
219 
220 /// Consumer that is used to get some pipeline information
221 class DatasetSizeGetter : public TreeConsumer, public std::enable_shared_from_this<DatasetSizeGetter> {
222  public:
DatasetSizeGetter()223   DatasetSizeGetter() : dataset_size_(-1) {}
224   ~DatasetSizeGetter() = default;
225   Status Init(std::shared_ptr<DatasetNode> d) override;
226   Status Terminate() override;
227 
228   /// \brief Function to get the dataset size
229   /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
230   ///     dataset size at the expense of accuracy.
231   /// \param[out] dataset_size the size of the dataset
232   /// \return Status of the function
233   Status GetDatasetSize(int64_t *size, bool estimate = false);
234 
235   virtual Status GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *row);
Name()236   std::string Name() override { return "DatasetSizeGetter"; }
237 
238   /// \brief Gets the dataset size by iterating over the entire dataset on a sub tree starting from ir_node
239   /// param[in] ir_node The node that marks the top most of the sub tree on which we want to iterate
240   /// \return Status - The status code return
241   Status DryRun(std::shared_ptr<DatasetNode> ir_node, int64_t *dataset_size);
242 
243  private:
244   std::shared_ptr<DatasetNode> root_;
245   std::vector<std::shared_ptr<TreeAdapter>> tree_adapters_;  // this is vector to handle different branch of zip
246   int64_t dataset_size_;
247 };
248 
249 class BuildVocabConsumer : public TreeConsumer {
250  public:
251   /// BuildVocabConsumer Constructor which will call the base class default constructor.
252   BuildVocabConsumer() = default;
253 
254   ~BuildVocabConsumer() = default;
255 
256   Status Init(std::shared_ptr<DatasetNode> d) override;
257 
258   /// Start consuming
259   /// \return  Status error code
260   virtual Status Start();
261 
262  protected:
263   /// Method to return the name of the consumer
264   /// \return string
Name()265   std::string Name() override { return "BuildVocab"; }
266 };
267 
268 }  // namespace mindspore::dataset
269 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_TREE_CONSUMER_H_
270