1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_ 19 20 #include <map> 21 #include <memory> 22 #include <string> 23 #include <unordered_set> 24 #include <utility> 25 #include <vector> 26 27 #include "minddata/dataset/core/config_manager.h" 28 #include "minddata/dataset/engine/consumers/tree_consumer.h" 29 #include "minddata/dataset/engine/data_schema.h" 30 #include "minddata/dataset/engine/datasetops/dataset_op.h" 31 #include "minddata/dataset/engine/datasetops/filter_op.h" 32 #include "minddata/dataset/engine/datasetops/map_op/map_op.h" 33 #include "minddata/dataset/engine/datasetops/project_op.h" 34 #include "minddata/dataset/engine/datasetops/repeat_op.h" 35 #include "minddata/dataset/engine/datasetops/shuffle_op.h" 36 #include "minddata/dataset/engine/datasetops/skip_op.h" 37 #include "minddata/dataset/engine/datasetops/take_op.h" 38 #include "minddata/dataset/engine/ir/cache/dataset_cache.h" 39 #include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h" 40 #include "minddata/dataset/include/dataset/datasets.h" 41 #include "minddata/dataset/util/path.h" 42 #include "minddata/dataset/util/status.h" 43 44 namespace mindspore { 45 namespace dataset { 46 47 class Dataset; 48 class DatasetCache; 49 class SamplerObj; 50 class IRNodePass; 51 class DatasetSizeGetter; 52 53 // Names for non-leaf IR node 54 constexpr char kBatchNode[] = "Batch"; 55 constexpr char kBucketBatchByLengthNode[] = "BucketBatchByLength"; 56 constexpr char kBuildSentencePieceVocabNode[] = "BuildSentencePieceVocab"; 57 constexpr char kBuildVocabNode[] = "BuildVocab"; 58 constexpr char kCacheLookupNode[] = "CacheLookup"; 59 constexpr char kCacheMergeNode[] = "CacheMerge"; 60 constexpr char kCacheNode[] = "Cache"; 61 constexpr char kConcatNode[] = "Concat"; 62 constexpr char kEpochCtrlNode[] = "EpochCtrl"; 63 constexpr char kFilterNode[] = "Filter"; 64 constexpr char kMapNode[] = "Map"; 65 constexpr char kProjectNode[] = "Project"; 66 constexpr char kRenameNode[] = "Rename"; 67 constexpr char kRepeatNode[] = "Repeat"; 68 constexpr char kRootNode[] = "Top"; 69 constexpr char kShuffleNode[] = "Shuffle"; 70 constexpr char kSkipNode[] = "Skip"; 71 constexpr char kSyncWaitNode[] = "SyncWait"; 72 constexpr char kTakeNode[] = "Take"; 73 constexpr char kTransferNode[] = "Transfer"; 74 constexpr char kZipNode[] = "Zip"; 75 76 // Names for leaf IR node 77 constexpr char kAlbumNode[] = "AlbumDataset"; 78 constexpr char kCelebANode[] = "CelebADataset"; 79 constexpr char kCifar100Node[] = "Cifar100Dataset"; 80 constexpr char kCifar10Node[] = "Cifar10Dataset"; 81 constexpr char kCityscapesNode[] = "CityscapesDataset"; 82 constexpr char kCLUENode[] = "CLUEDataset"; 83 constexpr char kCocoNode[] = "CocoDataset"; 84 constexpr char kCSVNode[] = "CSVDataset"; 85 constexpr char kDIV2KNode[] = "DIV2KDataset"; 86 constexpr char kFlickrNode[] = "FlickrDataset"; 87 constexpr char kGeneratorNode[] = "GeneratorDataset"; 88 constexpr char kImageFolderNode[] = "ImageFolderDataset"; 89 constexpr char kManifestNode[] = "ManifestDataset"; 90 constexpr char kMindDataNode[] = "MindDataDataset"; 91 constexpr char kMnistNode[] = "MnistDataset"; 92 constexpr char kRandomNode[] = "RandomDataset"; 93 constexpr char kSBUNode[] = "SBUDataset"; 94 constexpr char kTextFileNode[] = "TextFileDataset"; 95 constexpr char kTFRecordNode[] = "TFRecordDataset"; 96 constexpr char kUSPSNode[] = "USPSDataset"; 97 constexpr char kVOCNode[] = "VOCDataset"; 98 99 Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, 100 int32_t connector_que_size, std::shared_ptr<DatasetOp> *shuffle_op); 101 102 // Helper function to validate dataset files parameter 103 Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files); 104 105 // Helper function to validate dataset num_shards and shard_id parameters 106 Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id); 107 108 // Helper function to validate dataset sampler parameter 109 Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler); 110 111 Status ValidateStringValue(const std::string &dataset_name, const std::string &str, 112 const std::unordered_set<std::string> &valid_strings); 113 114 // Helper function to validate dataset input/output column parameterCD - 115 Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param, 116 const std::vector<std::string> &columns); 117 118 // Helper function to validate dataset directory parameter 119 Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir); 120 121 /// \brief Function to create a sampler for non-mappable dataset (to be used by cache op later). 122 /// \notes Non-mappable dataset does not directly support a sampler. It has provided sampling arguments (shuffle, 123 /// num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in the pipeline contains 124 /// a cache. If there is no cache above it, then the sampler is not used. 125 /// \param[in] num_samples The number of samples to be included in the dataset. 126 /// \param[in] shuffle If true, the indices are shuffled. 127 /// \param[in] num_shards Number of shards to divide the dataset into. 128 /// \param[in] shard_id Shard ID of the current shard within num_shards. 129 /// \return Shared pointer to the current Sampler. 130 std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id); 131 132 // The base class of all IR nodes 133 class DatasetNode : public std::enable_shared_from_this<DatasetNode> { 134 // Allow DeepCopyPass to access internal members 135 friend class DeepCopyPass; 136 137 public: 138 /// \brief Constructor 139 DatasetNode(); 140 141 /// \brief Constructor that initializes the cache 142 /// \param dataset_cache DatasetCache 143 explicit DatasetNode(const std::shared_ptr<DatasetCache> &dataset_cache); 144 145 /// \brief Destructor 146 ~DatasetNode() = default; 147 148 /// \brief Node name getter 149 /// \return Name of the current node 150 virtual std::string Name() const = 0; 151 152 /// \brief Pure virtual function to print the description 153 /// \param out - The output stream to write output to 154 virtual void Print(std::ostream &out) const = 0; 155 156 /// \brief Pure virtual function to clone a new copy of the node 157 /// \return The new copy of the node 158 virtual std::shared_ptr<DatasetNode> Copy() = 0; 159 160 /// \brief Print the IR tree to output stream 161 /// \param out - The output stream to write output to 162 void PrintTree(std::ostream &out) const; 163 164 /// \brief << Stream output operator overload 165 /// \notes This allows you to write the debug print info using stream operators 166 /// \param out - reference to the output stream being overloaded 167 /// \param node - reference to the DatasetNode to display 168 /// \return - the output stream must be returned 169 friend std::ostream &operator<<(std::ostream &out, const DatasetNode &node) { 170 node.PrintTree(out); 171 return out; 172 } 173 174 /// \brief Pure virtual function to convert a DatasetNode class into a runtime dataset object 175 /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create 176 /// \return Status Status::OK() if build successfully 177 virtual Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) = 0; 178 179 /// \brief base virtual function for derived class to implement parameters validation 180 /// \return Status Status::OK() if all the parameters are valid 181 virtual Status ValidateParams(); 182 183 /// \brief Pure virtual function for derived class to get the shard id of specific node 184 /// \return Status Status::OK() if get shard id successfully 185 virtual Status GetShardId(int32_t *const shard_id); 186 187 /// \brief Gets the dataset size 188 /// \param[in] size_getter Shared pointer to DatasetSizeGetter 189 /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting 190 /// dataset size at the expense of accuracy. 191 /// \return Status - The status code return 192 virtual Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, 193 int64_t *dataset_size); 194 195 /// \brief Getter function for child nodes 196 /// \return Child nodes Children()197 const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children_; } 198 199 /// \brief Establish a parent-child relationship between this node and the input node. 200 /// Used during the cloning of the user-input IR tree (temporary use) 201 Status AppendChild(std::shared_ptr<DatasetNode> child); 202 203 /// \brief Insert the input <node> above this node 204 Status InsertAbove(std::shared_ptr<DatasetNode> node); 205 206 /// \brief Add the input node as the next sibling (future use) 207 Status InsertChildAt(int32_t pos, std::shared_ptr<DatasetNode> node); 208 209 /// \brief detach this node from its parent, add its child (if any) to its parent 210 /// \return error code, return error if node has more than 1 children 211 Status Drop(); 212 213 /// \brief Check if this node has cache 214 /// \return True if the data of this node will be cached IsCached()215 const bool IsCached() const { return (cache_ != nullptr); } 216 217 /// \brief Check if this node is a leaf node. 218 /// \return True if this is a leaf node. IsLeaf()219 const bool IsLeaf() const { return children_.empty(); } 220 221 /// \brief Check if this node is a unary operator node. 222 /// \return True if this node is semantically a unary operator node IsUnaryOperator()223 const bool IsUnaryOperator() const { return (mappable_ == kNotADataSource && !nary_op_); } 224 225 /// \brief Check if this node is a n-ary operator node. 226 /// \return True if this node is semantically a n-ary operator node IsNaryOperator()227 const bool IsNaryOperator() const { return (mappable_ == kNotADataSource && nary_op_); } 228 229 /// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes 230 /// \return True if this node is a mappable dataset IsMappableDataSource()231 const bool IsMappableDataSource() const { return (mappable_ == kMappableSource); } 232 233 /// \brief Check if this node is a non-mappable dataset. Only applicable to leaf nodes 234 /// \return True if this node is a non-mappable dataset IsNonMappableDataSource()235 const bool IsNonMappableDataSource() const { return (mappable_ == kNonMappableSource); } 236 237 /// \brief Check if this node is a data source node. 238 /// \return True if this node is a data source node IsDataSource()239 const bool IsDataSource() const { return (mappable_ == kMappableSource || mappable_ == kNonMappableSource); } 240 241 /// \brief Check if this node is not a data source node. 242 /// \return True if this node is not a data source node IsNotADataSource()243 const bool IsNotADataSource() const { return (mappable_ == kNotADataSource); } 244 245 /// \brief Check if this node is a descendant of an operator with cache. 246 /// \return True if a cache-enabled operator is an ancestor of this node IsDescendantOfCache()247 const bool IsDescendantOfCache() const { return descendant_of_cache_; } 248 249 /// \brief Check if this node is an orphan node 250 /// \return True if this node isn't nullptr nor does it have any children and a parent IsOrphanNode(std::shared_ptr<DatasetNode> node)251 static bool IsOrphanNode(std::shared_ptr<DatasetNode> node) { 252 return node != nullptr && node->parent_ == nullptr && node->Children().empty(); 253 } 254 255 /// \brief Mark to indicate this node is a descendant of an operator with cache. HasCacheAbove()256 void HasCacheAbove() { descendant_of_cache_ = true; } 257 258 /// \brief Getter of the number of workers NumWorkers()259 int32_t NumWorkers() { return num_workers_; } 260 261 /// \brief Getter of dataset cache GetDatasetCache()262 std::shared_ptr<DatasetCache> GetDatasetCache() { return cache_; } 263 264 /// \brief Setter function for runtime number of workers 265 /// \param[in] num_workers The number of threads in this operator 266 /// \return Shared pointer to the original object 267 std::shared_ptr<DatasetNode> SetNumWorkers(int32_t num_workers); 268 269 /// \brief Setter function for DatasetCache 270 /// \param[in] cache Shared pointer to DatasetCache 271 /// \return Shared pointer to the original object 272 std::shared_ptr<DatasetNode> SetDatasetCache(const std::shared_ptr<DatasetCache> &cache); 273 274 /// \brief A helper templated function for casting "this" pointer to shared_ptr<derived> 275 /// Similar to shared_from_this, except this one will give you the derived class as shared_ptr 276 /// \return A shared_ptr casted to the derived class 277 template <typename Derived> shared_from_base()278 std::shared_ptr<Derived> shared_from_base() { 279 return std::static_pointer_cast<Derived>(shared_from_this()); 280 } 281 282 /// \brief Base method for IRNodePass visit. A tree walk consists of walking down the tree and also walking back up 283 /// in a depth-first order. Accept is the node visit on the way down, whereas AcceptAfter is the node 284 /// visit on the way back up the tree after its descendants are visited. 285 /// \notes Subclass needs to override this if it requires special node visit access. 286 /// Check "dataset/engine/opt/pass.h" for more details. 287 /// \param[in] p The node to visit 288 /// \param[out] modified Indicator if the node was modified 289 /// \return Status of the node visit 290 virtual Status Accept(IRNodePass *const p, bool *const modified); 291 292 /// \brief Base method for IRNodePass visit on the way back up the tree after its descendants are visited. 293 /// \notes Subclass needs to override this if it requires special node visit access. 294 /// Check "dataset/engine/opt/pass.h" for more details. 295 /// \param[in] p The node to visit 296 /// \param[out] modified Indicator if the node was modified 297 /// \return Status of the node visit 298 virtual Status AcceptAfter(IRNodePass *const p, bool *const modified); 299 IsSizeDefined()300 virtual bool IsSizeDefined() { return true; } 301 302 /// \brief Get the arguments of node 303 /// \param[out] out_json JSON string of all attributes 304 /// \return Status of the function 305 virtual Status to_json(nlohmann::json *out_json); 306 307 /// \brief Setter function, set the number of total repeats for the operator SetTotalRepeats(int32_t total_repeats)308 void SetTotalRepeats(int32_t total_repeats) { total_repeats_ = total_repeats; } 309 310 /// \brief Setter function, set the number of epochs for the operator SetNumEpochs(int32_t num_epochs)311 void SetNumEpochs(int32_t num_epochs) { num_epochs_ = num_epochs; } 312 313 /// \brief Getter function 314 /// \return The number of required repeats for the operator GetTotalRepeats()315 int32_t GetTotalRepeats() const { return total_repeats_; } 316 317 /// \brief Getter function 318 /// \return The number of epochs for the operator GetNumEpochs()319 int32_t GetNumEpochs() const { return num_epochs_; } 320 321 /// \brief Getter function 322 /// \return The number of repeats per epoch for the operator GetNumRepeatsPerEpoch()323 int32_t GetNumRepeatsPerEpoch() const { return total_repeats_ / num_epochs_; } 324 325 protected: 326 std::vector<std::shared_ptr<DatasetNode>> children_; 327 DatasetNode *parent_; // used to record the only one parent of an IR node after parsing phase 328 std::shared_ptr<DatasetCache> cache_; 329 int64_t dataset_size_; 330 int32_t num_workers_; 331 int32_t connector_que_size_; 332 int32_t worker_connector_size_; 333 int32_t total_repeats_; // Number of times required to run this operator 334 int32_t num_epochs_; // Number of epochs 335 // Establish a parent-child relationship between this node and the input node. 336 // Used only in the constructor of the class and its derived classes. 337 void AddChild(std::shared_ptr<DatasetNode> child); 338 std::string PrintColumns(const std::vector<std::string> &columns) const; 339 void PrintNode(std::ostream &out, int *level) const; 340 enum DataSource { kNotADataSource = 0, kNonMappableSource = 1, kMappableSource = 2 }; 341 enum DataSource mappable_; 342 bool nary_op_; // an indicator of whether the current node supports multiple children, true for concat/zip node 343 bool descendant_of_cache_; // an indicator of whether the current node is a descendant of cache. 344 // Initially set to false, will set to true by the optimizer when conditions are met. 345 }; 346 347 // MappableSourceNode represents the leaf nodes that can be randomly accessed with indexes. 348 class MappableSourceNode : public DatasetNode { 349 public: 350 /// \brief Constructor MappableSourceNode()351 MappableSourceNode() : DatasetNode() { mappable_ = kMappableSource; } 352 353 /// \brief Constructor that initializes the cache 354 /// \param dataset_cache DatasetCache MappableSourceNode(const std::shared_ptr<DatasetCache> & dataset_cache)355 explicit MappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) { 356 mappable_ = kMappableSource; 357 // Initially set to false, and set to true by the optimizer when conditions are met. 358 descendant_of_cache_ = false; 359 } 360 361 Status Accept(IRNodePass *const p, bool *const modified) override; 362 363 /// \brief Destructor 364 ~MappableSourceNode() = default; 365 366 /// \brief Node name getter 367 /// \return Name of the current node 368 virtual std::string Name() const = 0; 369 370 /// \brief Sampler getter 371 /// \return SamplerObj of the current node 372 virtual std::shared_ptr<SamplerObj> Sampler() = 0; 373 374 /// \brief Sampler setter 375 virtual void SetSampler(std::shared_ptr<SamplerObj> sampler) = 0; 376 }; 377 378 // NonMappableSourceNode represents the leaf nodes that can not be randomly accessed. 379 class NonMappableSourceNode : public DatasetNode { 380 public: 381 /// \brief Constructor NonMappableSourceNode()382 NonMappableSourceNode() : DatasetNode() { mappable_ = kNonMappableSource; } 383 384 /// \brief Constructor that initializes the cache 385 /// \param dataset_cache DatasetCache NonMappableSourceNode(const std::shared_ptr<DatasetCache> & dataset_cache)386 explicit NonMappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) { 387 mappable_ = kNonMappableSource; 388 // Initially set to false, and set to true by the optimizer when conditions are met. 389 descendant_of_cache_ = false; 390 } 391 392 Status Accept(IRNodePass *const p, bool *const modified) override; 393 394 /// \brief Destructor 395 ~NonMappableSourceNode() = default; 396 397 /// \brief Node name getter 398 /// \return Name of the current node 399 virtual std::string Name() const = 0; 400 401 /// \brief By default non-mappable dataset does not support sampling. However, if a cache operator 402 /// is injected at some other place higher in the tree, that cache can inherit this sampler 403 /// from the leaf, providing sampling support from the caching layer. 404 /// This function sets up the sampler for a leaf node that does not use sampling. 405 /// \param[in] sampler The sampler to setup 406 /// \return Status of the function 407 virtual Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) = 0; 408 409 /// \brief If a cache has been added into the ascendant tree over this non-mappable source node, then the cache will 410 /// be executing a sampler for fetching the data. As such, any options in the source node need to be reset to its 411 /// defaults so that this source node will produce the full set of data into the cache. 412 /// \return Status of the function 413 virtual Status MakeSimpleProducer() = 0; 414 }; 415 } // namespace dataset 416 } // namespace mindspore 417 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_ 418