• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <unordered_set>
24 #include <utility>
25 #include <vector>
26 
27 #include "minddata/dataset/core/config_manager.h"
28 #include "minddata/dataset/engine/consumers/tree_consumer.h"
29 #include "minddata/dataset/engine/data_schema.h"
30 #include "minddata/dataset/engine/datasetops/dataset_op.h"
31 #include "minddata/dataset/engine/datasetops/filter_op.h"
32 #include "minddata/dataset/engine/datasetops/map_op/map_op.h"
33 #include "minddata/dataset/engine/datasetops/project_op.h"
34 #include "minddata/dataset/engine/datasetops/repeat_op.h"
35 #include "minddata/dataset/engine/datasetops/shuffle_op.h"
36 #include "minddata/dataset/engine/datasetops/skip_op.h"
37 #include "minddata/dataset/engine/datasetops/take_op.h"
38 #include "minddata/dataset/engine/ir/cache/dataset_cache.h"
39 #include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
40 #include "minddata/dataset/include/dataset/datasets.h"
41 #include "minddata/dataset/util/path.h"
42 #include "minddata/dataset/util/status.h"
43 
44 namespace mindspore {
45 namespace dataset {
46 
// Forward declarations; none of these classes is defined in this header.
47 class Dataset;
48 class DatasetCache;
49 class SamplerObj;
50 class IRNodePass;
51 class DatasetSizeGetter;
52 
53 // Name strings of the non-leaf IR nodes, one constant per node type.
// Each string equals the constant's stem (the identifier without the leading "k" and trailing "Node"),
// except kRootNode, whose string is "Top".
54 constexpr char kBatchNode[] = "Batch";
55 constexpr char kBucketBatchByLengthNode[] = "BucketBatchByLength";
56 constexpr char kBuildSentencePieceVocabNode[] = "BuildSentencePieceVocab";
57 constexpr char kBuildVocabNode[] = "BuildVocab";
58 constexpr char kCacheLookupNode[] = "CacheLookup";
59 constexpr char kCacheMergeNode[] = "CacheMerge";
60 constexpr char kCacheNode[] = "Cache";
61 constexpr char kConcatNode[] = "Concat";
62 constexpr char kEpochCtrlNode[] = "EpochCtrl";
63 constexpr char kFilterNode[] = "Filter";
64 constexpr char kMapNode[] = "Map";
65 constexpr char kProjectNode[] = "Project";
66 constexpr char kRenameNode[] = "Rename";
67 constexpr char kRepeatNode[] = "Repeat";
68 constexpr char kRootNode[] = "Top";
69 constexpr char kShuffleNode[] = "Shuffle";
70 constexpr char kSkipNode[] = "Skip";
71 constexpr char kSyncWaitNode[] = "SyncWait";
72 constexpr char kTakeNode[] = "Take";
73 constexpr char kTransferNode[] = "Transfer";
74 constexpr char kZipNode[] = "Zip";
75 
76 // Name strings of the leaf IR nodes, one constant per node type.
// Each string is the constant's stem with the "Node" suffix replaced by "Dataset".
77 constexpr char kAlbumNode[] = "AlbumDataset";
78 constexpr char kCelebANode[] = "CelebADataset";
79 constexpr char kCifar100Node[] = "Cifar100Dataset";
80 constexpr char kCifar10Node[] = "Cifar10Dataset";
81 constexpr char kCityscapesNode[] = "CityscapesDataset";
82 constexpr char kCLUENode[] = "CLUEDataset";
83 constexpr char kCocoNode[] = "CocoDataset";
84 constexpr char kCSVNode[] = "CSVDataset";
85 constexpr char kDIV2KNode[] = "DIV2KDataset";
86 constexpr char kFlickrNode[] = "FlickrDataset";
87 constexpr char kGeneratorNode[] = "GeneratorDataset";
88 constexpr char kImageFolderNode[] = "ImageFolderDataset";
89 constexpr char kManifestNode[] = "ManifestDataset";
90 constexpr char kMindDataNode[] = "MindDataDataset";
91 constexpr char kMnistNode[] = "MnistDataset";
92 constexpr char kRandomNode[] = "RandomDataset";
93 constexpr char kSBUNode[] = "SBUDataset";
94 constexpr char kTextFileNode[] = "TextFileDataset";
95 constexpr char kTFRecordNode[] = "TFRecordDataset";
96 constexpr char kUSPSNode[] = "USPSDataset";
97 constexpr char kVOCNode[] = "VOCDataset";
98 
// Helper function declared here and defined elsewhere; its behavior is not visible in this header.
// NOTE(review): judging by the name and signature it presumably creates a shuffle operator and hands it back
//     through the out-parameter shuffle_op - confirm against the definition.
99 Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
100                     int32_t connector_que_size, std::shared_ptr<DatasetOp> *shuffle_op);
101 
102 // Helper function to validate dataset files parameter
103 Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files);
104 
105 // Helper function to validate dataset num_shards and shard_id parameters
106 Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id);
107 
108 // Helper function to validate dataset sampler parameter
109 Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler);
110 
// Helper function to validate a string parameter.
// NOTE(review): presumably checks that str is one of valid_strings - confirm against the definition.
111 Status ValidateStringValue(const std::string &dataset_name, const std::string &str,
112                            const std::unordered_set<std::string> &valid_strings);
113 
114 // Helper function to validate dataset input/output column parameter
115 Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param,
116                                   const std::vector<std::string> &columns);
117 
118 // Helper function to validate dataset directory parameter
119 Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir);
120 
121 /// \brief Function to create a sampler for non-mappable dataset (to be used by cache op later).
122 /// \notes Non-mappable dataset does not directly support a sampler. It has provided sampling arguments (shuffle,
123 ///     num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in the pipeline contains
124 ///     a cache. If there is no cache above it, then the sampler is not used.
125 /// \param[in] num_samples The number of samples to be included in the dataset.
126 /// \param[in] shuffle If true, the indices are shuffled.
127 /// \param[in] num_shards Number of shards to divide the dataset into.
128 /// \param[in] shard_id Shard ID of the current shard within num_shards.
129 /// \return Shared pointer to the current Sampler.
130 std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id);
131 
132 // The base class of all IR nodes
133 class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
134   // Allow DeepCopyPass to access internal members
135   friend class DeepCopyPass;
136 
137  public:
138   /// \brief Constructor
139   DatasetNode();
140 
141   /// \brief Constructor that initializes the cache
142   /// \param dataset_cache DatasetCache
143   explicit DatasetNode(const std::shared_ptr<DatasetCache> &dataset_cache);
144 
145   /// \brief Destructor
146   ~DatasetNode() = default;
147 
148   /// \brief Node name getter
149   /// \return Name of the current node
150   virtual std::string Name() const = 0;
151 
152   /// \brief Pure virtual function to print the description
153   /// \param out - The output stream to write output to
154   virtual void Print(std::ostream &out) const = 0;
155 
156   /// \brief Pure virtual function to clone a new copy of the node
157   /// \return The new copy of the node
158   virtual std::shared_ptr<DatasetNode> Copy() = 0;
159 
160   /// \brief Print the IR tree to output stream
161   /// \param out - The output stream to write output to
162   void PrintTree(std::ostream &out) const;
163 
164   /// \brief << Stream output operator overload
165   /// \notes This allows you to write the debug print info using stream operators
166   /// \param out - reference to the output stream being overloaded
167   /// \param node - reference to the DatasetNode to display
168   /// \return - the output stream must be returned
169   friend std::ostream &operator<<(std::ostream &out, const DatasetNode &node) {
170     node.PrintTree(out);
171     return out;
172   }
173 
174   /// \brief Pure virtual function to convert a DatasetNode class into a runtime dataset object
175   /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create
176   /// \return Status Status::OK() if build successfully
177   virtual Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) = 0;
178 
179   /// \brief base virtual function for derived class to implement parameters validation
180   /// \return Status Status::OK() if all the parameters are valid
181   virtual Status ValidateParams();
182 
183   /// \brief Pure virtual function for derived class to get the shard id of specific node
184   /// \return Status Status::OK() if get shard id successfully
185   virtual Status GetShardId(int32_t *const shard_id);
186 
187   /// \brief Gets the dataset size
188   /// \param[in] size_getter Shared pointer to DatasetSizeGetter
189   /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
190   ///     dataset size at the expense of accuracy.
191   /// \return Status - The status code return
192   virtual Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
193                                 int64_t *dataset_size);
194 
195   /// \brief Getter function for child nodes
196   /// \return Child nodes
Children()197   const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children_; }
198 
199   /// \brief Establish a parent-child relationship between this node and the input node.
200   ///    Used during the cloning of the user-input IR tree (temporary use)
201   Status AppendChild(std::shared_ptr<DatasetNode> child);
202 
203   /// \brief Insert the input <node> above this node
204   Status InsertAbove(std::shared_ptr<DatasetNode> node);
205 
206   /// \brief Add the input node as the next sibling (future use)
207   Status InsertChildAt(int32_t pos, std::shared_ptr<DatasetNode> node);
208 
209   /// \brief detach this node from its parent, add its child (if any) to its parent
210   /// \return error code, return error if node has more than 1 children
211   Status Drop();
212 
213   /// \brief Check if this node has cache
214   /// \return True if the data of this node will be cached
IsCached()215   const bool IsCached() const { return (cache_ != nullptr); }
216 
217   /// \brief Check if this node is a leaf node.
218   /// \return True if this is a leaf node.
IsLeaf()219   const bool IsLeaf() const { return children_.empty(); }
220 
221   /// \brief Check if this node is a unary operator node.
222   /// \return True if this node is semantically a unary operator node
IsUnaryOperator()223   const bool IsUnaryOperator() const { return (mappable_ == kNotADataSource && !nary_op_); }
224 
225   /// \brief Check if this node is a n-ary operator node.
226   /// \return True if this node is semantically a n-ary operator node
IsNaryOperator()227   const bool IsNaryOperator() const { return (mappable_ == kNotADataSource && nary_op_); }
228 
229   /// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes
230   /// \return True if this node is a mappable dataset
IsMappableDataSource()231   const bool IsMappableDataSource() const { return (mappable_ == kMappableSource); }
232 
233   /// \brief Check if this node is a non-mappable dataset. Only applicable to leaf nodes
234   /// \return True if this node is a non-mappable dataset
IsNonMappableDataSource()235   const bool IsNonMappableDataSource() const { return (mappable_ == kNonMappableSource); }
236 
237   /// \brief Check if this node is a data source node.
238   /// \return True if this node is a data source node
IsDataSource()239   const bool IsDataSource() const { return (mappable_ == kMappableSource || mappable_ == kNonMappableSource); }
240 
241   /// \brief Check if this node is not a data source node.
242   /// \return True if this node is not a data source node
IsNotADataSource()243   const bool IsNotADataSource() const { return (mappable_ == kNotADataSource); }
244 
245   /// \brief Check if this node is a descendant of an operator with cache.
246   /// \return True if a cache-enabled operator is an ancestor of this node
IsDescendantOfCache()247   const bool IsDescendantOfCache() const { return descendant_of_cache_; }
248 
249   /// \brief Check if this node is an orphan node
250   /// \return True if this node isn't nullptr nor does it have any children and a parent
IsOrphanNode(std::shared_ptr<DatasetNode> node)251   static bool IsOrphanNode(std::shared_ptr<DatasetNode> node) {
252     return node != nullptr && node->parent_ == nullptr && node->Children().empty();
253   }
254 
255   /// \brief Mark to indicate this node is a descendant of an operator with cache.
HasCacheAbove()256   void HasCacheAbove() { descendant_of_cache_ = true; }
257 
258   /// \brief Getter of the number of workers
NumWorkers()259   int32_t NumWorkers() { return num_workers_; }
260 
261   /// \brief Getter of dataset cache
GetDatasetCache()262   std::shared_ptr<DatasetCache> GetDatasetCache() { return cache_; }
263 
264   /// \brief Setter function for runtime number of workers
265   /// \param[in] num_workers The number of threads in this operator
266   /// \return Shared pointer to the original object
267   std::shared_ptr<DatasetNode> SetNumWorkers(int32_t num_workers);
268 
269   /// \brief Setter function for DatasetCache
270   /// \param[in] cache Shared pointer to DatasetCache
271   /// \return Shared pointer to the original object
272   std::shared_ptr<DatasetNode> SetDatasetCache(const std::shared_ptr<DatasetCache> &cache);
273 
274   /// \brief A helper templated function for casting "this" pointer to shared_ptr<derived>
275   ///     Similar to shared_from_this, except this one will give you the derived class as shared_ptr
276   /// \return A shared_ptr casted to the derived class
277   template <typename Derived>
shared_from_base()278   std::shared_ptr<Derived> shared_from_base() {
279     return std::static_pointer_cast<Derived>(shared_from_this());
280   }
281 
282   /// \brief Base method for IRNodePass visit. A tree walk consists of walking down the tree and also walking back up
283   ///     in a depth-first order. Accept is the node visit on the way down, whereas AcceptAfter is the node
284   ///     visit on the way back up the tree after its descendants are visited.
285   /// \notes Subclass needs to override this if it requires special node visit access.
286   ///     Check "dataset/engine/opt/pass.h" for more details.
287   /// \param[in] p The node to visit
288   /// \param[out] modified Indicator if the node was modified
289   /// \return Status of the node visit
290   virtual Status Accept(IRNodePass *const p, bool *const modified);
291 
292   /// \brief Base method for IRNodePass visit on the way back up the tree after its descendants are visited.
293   /// \notes Subclass needs to override this if it requires special node visit access.
294   ///     Check "dataset/engine/opt/pass.h" for more details.
295   /// \param[in] p The node to visit
296   /// \param[out] modified Indicator if the node was modified
297   /// \return Status of the node visit
298   virtual Status AcceptAfter(IRNodePass *const p, bool *const modified);
299 
IsSizeDefined()300   virtual bool IsSizeDefined() { return true; }
301 
302   /// \brief Get the arguments of node
303   /// \param[out] out_json JSON string of all attributes
304   /// \return Status of the function
305   virtual Status to_json(nlohmann::json *out_json);
306 
307   /// \brief Setter function, set the number of total repeats for the operator
SetTotalRepeats(int32_t total_repeats)308   void SetTotalRepeats(int32_t total_repeats) { total_repeats_ = total_repeats; }
309 
310   /// \brief Setter function, set the number of epochs for the operator
SetNumEpochs(int32_t num_epochs)311   void SetNumEpochs(int32_t num_epochs) { num_epochs_ = num_epochs; }
312 
313   /// \brief Getter function
314   /// \return The number of required repeats for the operator
GetTotalRepeats()315   int32_t GetTotalRepeats() const { return total_repeats_; }
316 
317   /// \brief Getter function
318   /// \return The number of epochs for the operator
GetNumEpochs()319   int32_t GetNumEpochs() const { return num_epochs_; }
320 
321   /// \brief Getter function
322   /// \return The number of repeats per epoch for the operator
GetNumRepeatsPerEpoch()323   int32_t GetNumRepeatsPerEpoch() const { return total_repeats_ / num_epochs_; }
324 
325  protected:
326   std::vector<std::shared_ptr<DatasetNode>> children_;
327   DatasetNode *parent_;  // used to record the only one parent of an IR node after parsing phase
328   std::shared_ptr<DatasetCache> cache_;
329   int64_t dataset_size_;
330   int32_t num_workers_;
331   int32_t connector_que_size_;
332   int32_t worker_connector_size_;
333   int32_t total_repeats_;  // Number of times required to run this operator
334   int32_t num_epochs_;     // Number of epochs
335   // Establish a parent-child relationship between this node and the input node.
336   // Used only in the constructor of the class and its derived classes.
337   void AddChild(std::shared_ptr<DatasetNode> child);
338   std::string PrintColumns(const std::vector<std::string> &columns) const;
339   void PrintNode(std::ostream &out, int *level) const;
340   enum DataSource { kNotADataSource = 0, kNonMappableSource = 1, kMappableSource = 2 };
341   enum DataSource mappable_;
342   bool nary_op_;  // an indicator of whether the current node supports multiple children, true for concat/zip node
343   bool descendant_of_cache_;  // an indicator of whether the current node is a descendant of cache.
344                               // Initially set to false, will set to true by the optimizer when conditions are met.
345 };
346 
347 // MappableSourceNode represents the leaf nodes that can be randomly accessed with indexes.
348 class MappableSourceNode : public DatasetNode {
349  public:
350   /// \brief Constructor
MappableSourceNode()351   MappableSourceNode() : DatasetNode() { mappable_ = kMappableSource; }
352 
353   /// \brief Constructor that initializes the cache
354   /// \param dataset_cache DatasetCache
MappableSourceNode(const std::shared_ptr<DatasetCache> & dataset_cache)355   explicit MappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {
356     mappable_ = kMappableSource;
357     // Initially set to false, and set to true by the optimizer when conditions are met.
358     descendant_of_cache_ = false;
359   }
360 
361   Status Accept(IRNodePass *const p, bool *const modified) override;
362 
363   /// \brief Destructor
364   ~MappableSourceNode() = default;
365 
366   /// \brief Node name getter
367   /// \return Name of the current node
368   virtual std::string Name() const = 0;
369 
370   /// \brief Sampler getter
371   /// \return SamplerObj of the current node
372   virtual std::shared_ptr<SamplerObj> Sampler() = 0;
373 
374   /// \brief Sampler setter
375   virtual void SetSampler(std::shared_ptr<SamplerObj> sampler) = 0;
376 };
377 
378 // NonMappableSourceNode represents the leaf nodes that can not be randomly accessed.
379 class NonMappableSourceNode : public DatasetNode {
380  public:
381   /// \brief Constructor
NonMappableSourceNode()382   NonMappableSourceNode() : DatasetNode() { mappable_ = kNonMappableSource; }
383 
384   /// \brief Constructor that initializes the cache
385   /// \param dataset_cache DatasetCache
NonMappableSourceNode(const std::shared_ptr<DatasetCache> & dataset_cache)386   explicit NonMappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {
387     mappable_ = kNonMappableSource;
388     // Initially set to false, and set to true by the optimizer when conditions are met.
389     descendant_of_cache_ = false;
390   }
391 
392   Status Accept(IRNodePass *const p, bool *const modified) override;
393 
394   /// \brief Destructor
395   ~NonMappableSourceNode() = default;
396 
397   /// \brief Node name getter
398   /// \return Name of the current node
399   virtual std::string Name() const = 0;
400 
401   /// \brief By default non-mappable dataset does not support sampling. However, if a cache operator
402   ///     is injected at some other place higher in the tree, that cache can inherit this sampler
403   ///     from the leaf, providing sampling support from the caching layer.
404   ///     This function sets up the sampler for a leaf node that does not use sampling.
405   /// \param[in] sampler The sampler to setup
406   /// \return Status of the function
407   virtual Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) = 0;
408 
409   /// \brief If a cache has been added into the ascendant tree over this non-mappable source node, then the cache will
410   ///     be executing a sampler for fetching the data. As such, any options in the source node need to be reset to its
411   ///     defaults so that this source node will produce the full set of data into the cache.
412   /// \return Status of the function
413   virtual Status MakeSimpleProducer() = 0;
414 };
415 }  // namespace dataset
416 }  // namespace mindspore
417 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_
418