1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_TREE_CONSUMER_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_TREE_CONSUMER_H_ 18 19 #include <memory> 20 #include <string> 21 #include <map> 22 #include <unordered_map> 23 #include <utility> 24 #include <vector> 25 26 #include "minddata/dataset/engine/tree_adapter.h" 27 #include "minddata/dataset/text/vocab.h" 28 29 namespace mindspore::dataset { 30 // Forward declare 31 class TreeAdapter; 32 class DatasetNode; 33 34 /// A base class for tree consumers which would fetch rows from the tree pipeline 35 class TreeConsumer { 36 public: 37 /// Constructor that prepares an empty tree_adapter 38 TreeConsumer(); 39 40 /// \brief Destructor 41 ~TreeConsumer() = default; 42 /// Initializes the consumer, this involves constructing and preparing the tree. 43 /// \param d The dataset node that represent the root of the IR tree. 44 /// \return Status error code. 45 virtual Status Init(std::shared_ptr<DatasetNode> d); 46 47 /// Internal function to perform the termination 48 /// \return Status error code 49 virtual Status Terminate(); 50 51 protected: 52 /// The class owns the tree_adapter that handles execution tree operations. 53 std::unique_ptr<TreeAdapter> tree_adapter_; 54 /// Method to return the name of the consumer 55 /// \return string 56 virtual std::string Name() = 0; 57 }; 58 59 /// Consumer that iterates over the dataset and returns the rows one by one as a vector or a map 60 class IteratorConsumer : public TreeConsumer { 61 public: 62 /// Constructor which will call the base class default constructor. 63 /// \param num_epochs number of epochs. Default to -1 (infinite epochs). TreeConsumer()64 explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {} 65 66 ~IteratorConsumer() = default; 67 68 Status Init(std::shared_ptr<DatasetNode> d) override; 69 70 /// Returns the next row in a vector format 71 /// \param[out] out std::vector of Tensors 72 /// \return Status error code 73 Status GetNextAsVector(std::vector<TensorPtr> *out); 74 75 /// Returns the next row in as a map 76 /// \param[out] out std::map of string to Tensor 77 /// \return Status error code 78 Status GetNextAsMap(std::unordered_map<std::string, TensorPtr> *const out); 79 80 /// Returns the next row in as a vector 81 /// \param[out] out std::vector of pairs of string to Tensor 82 /// \return Status error code 83 Status GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *const vec); 84 85 protected: 86 /// Method to return the name of the consumer 87 /// \return string Name()88 std::string Name() override { return "IteratorConsumer"; } 89 90 private: 91 int32_t num_epochs_; 92 std::map<int32_t, std::string> column_order_; // key: column id, val: column name 93 }; 94 95 #ifndef ENABLE_ANDROID 96 /// Consumer that iterates over the dataset and writes it to disk 97 class SaveToDisk : public TreeConsumer { 98 public: 99 /// Constructor which will call the base class default constructor. 100 /// \param dataset_path path the the dataset 101 /// \param num_files number of files. Default to 1 102 /// \param dataset_type The format of the dataset. Default to "mindrecod". 103 explicit SaveToDisk(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord") TreeConsumer()104 : TreeConsumer(), dataset_path_(dataset_path), num_files_(num_files), dataset_type_(dataset_type) {} 105 106 ~SaveToDisk() = default; 107 108 /// \brief Parameters validation 109 /// \return Status Status::OK() if all the parameters are valid 110 Status ValidateParams(); 111 112 /// Save the given dataset to MindRecord format on disk. This is a blocking method (i.e., after returning, all rows 113 /// would be written to disk) 114 /// \return Status error code 115 virtual Status Save(); 116 117 protected: 118 /// Method to return the name of the consumer 119 /// \return string Name()120 std::string Name() override { return "SaveToDisk"; } 121 122 private: 123 template <typename T, typename S> 124 Status TransformTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, 125 std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr, 126 std::unique_ptr<S> *s, bool need_convert = false); 127 128 Status FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map, 129 const TensorRow &row, nlohmann::json *schema, std::vector<std::string> *index_fields); 130 131 Status FetchDataFromTensorRow(const TensorRow &row, 132 const std::unordered_map<std::string, int32_t> &column_name_id_map, 133 nlohmann::json *row_raw_data, 134 std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data); 135 136 Status FetchFloatData(std::shared_ptr<Tensor> tensor, std::string column_name, nlohmann::json *row_raw_data, 137 std::unique_ptr<std::vector<uint8_t>> *data_ptr); 138 139 Status FetchItemData(std::shared_ptr<Tensor> tensor, std::string column_name, nlohmann::json *row_raw_data, 140 std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data); 141 142 template <typename T> 143 bool map_compare(T const &lhs, T const &rhs); 144 145 Status CheckTensorRowShapes(const std::unordered_map<std::string, int32_t> &column_name_id_map, const TensorRow &row, 146 std::map<std::string, std::vector<int>> *PreTensorRowShapes_ptr); 147 148 std::string dataset_path_; 149 int32_t num_files_; 150 std::string dataset_type_; 151 }; 152 #endif 153 154 /// Consumer that iterates over the dataset and send it to a device 155 class ToDevice : public TreeConsumer { 156 public: TreeConsumer()157 explicit ToDevice(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {} 158 159 ~ToDevice() = default; 160 161 Status Init(std::shared_ptr<DatasetNode> d) override; 162 163 Status Terminate() override; 164 165 /// Send the data to device 166 /// \return Status error code 167 virtual Status Send(); 168 169 /// Stop to send data to device 170 /// \return Status error code 171 virtual Status Stop(); 172 173 /// Continue to send data to device 174 /// \return Status error code 175 virtual Status Continue(); 176 177 /// Get data info from TDT 178 /// \return Status error code 179 virtual Status GetDataInfo(std::vector<DataType> *const types, std::vector<TensorShape> *const shapes); 180 181 protected: 182 /// Method to return the name of the consumer 183 /// \return string Name()184 std::string Name() override { return "ToDevice"; } 185 186 private: 187 int32_t num_epochs_; 188 }; 189 190 /// Consumer that is used to get some pipeline information 191 class TreeGetters : public TreeConsumer { 192 public: 193 TreeGetters(); 194 ~TreeGetters() = default; 195 Status Init(std::shared_ptr<DatasetNode> d) override; 196 197 Status GetOutputTypes(std::vector<DataType> *types); 198 Status GetOutputShapes(std::vector<TensorShape> *shapes); 199 Status GetBatchSize(int64_t *batch_size); 200 Status GetRepeatCount(int64_t *repeat_count); 201 Status GetNumClasses(int64_t *num_classes); 202 Status GetColumnNames(std::vector<std::string> *output); 203 Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing); Name()204 std::string Name() override { return "TreeGetters"; } 205 virtual Status GetRow(TensorRow *row); 206 207 private: 208 Status GetFirstRowShapeAndType(); 209 210 std::shared_ptr<DatasetNode> root_; 211 int64_t dataset_size_; 212 std::vector<DataType> first_row_type_; 213 std::vector<TensorShape> first_row_shape_; 214 bool first_row_obtained_; // whether first row (which could be empty) is obtained by TreeGetter 215 bool init_flag_; // indicate whether the tree has initialized 216 217 Status InternalInit(); 218 }; 219 220 /// Consumer that is used to get some pipeline information 221 class DatasetSizeGetter : public TreeConsumer, public std::enable_shared_from_this<DatasetSizeGetter> { 222 public: DatasetSizeGetter()223 DatasetSizeGetter() : dataset_size_(-1) {} 224 ~DatasetSizeGetter() = default; 225 Status Init(std::shared_ptr<DatasetNode> d) override; 226 Status Terminate() override; 227 228 /// \brief Function to get the dataset size 229 /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting 230 /// dataset size at the expense of accuracy. 231 /// \param[out] dataset_size the size of the dataset 232 /// \return Status of the function 233 Status GetDatasetSize(int64_t *size, bool estimate = false); 234 235 virtual Status GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *row); Name()236 std::string Name() override { return "DatasetSizeGetter"; } 237 238 /// \brief Gets the dataset size by iterating over the entire dataset on a sub tree starting from ir_node 239 /// param[in] ir_node The node that marks the top most of the sub tree on which we want to iterate 240 /// \return Status - The status code return 241 Status DryRun(std::shared_ptr<DatasetNode> ir_node, int64_t *dataset_size); 242 243 private: 244 std::shared_ptr<DatasetNode> root_; 245 std::vector<std::shared_ptr<TreeAdapter>> tree_adapters_; // this is vector to handle different branch of zip 246 int64_t dataset_size_; 247 }; 248 249 class BuildVocabConsumer : public TreeConsumer { 250 public: 251 /// BuildVocabConsumer Constructor which will call the base class default constructor. 252 BuildVocabConsumer() = default; 253 254 ~BuildVocabConsumer() = default; 255 256 Status Init(std::shared_ptr<DatasetNode> d) override; 257 258 /// Start consuming 259 /// \return Status error code 260 virtual Status Start(); 261 262 protected: 263 /// Method to return the name of the consumer 264 /// \return string Name()265 std::string Name() override { return "BuildVocab"; } 266 }; 267 268 } // namespace mindspore::dataset 269 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_TREE_CONSUMER_H_ 270