• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
18 
19 #include <algorithm>
20 #include <limits>
21 #include <memory>
22 #include <set>
23 #include <utility>
24 
25 #include "minddata/dataset/engine/opt/pass.h"
26 #include "minddata/dataset/util/random.h"
27 #include "minddata/dataset/util/status.h"
28 
29 namespace mindspore {
30 namespace dataset {
31 
32 // Helper function to compute a default shuffle size
ComputeShuffleSize(int64_t num_files,int64_t num_devices,int64_t num_rows,int64_t total_rows,int64_t * shuffle_size)33 Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
34                           int64_t *shuffle_size) {
35   RETURN_UNEXPECTED_IF_NULL(shuffle_size);
36   const int64_t average_files_multiplier = 4;
37   const int64_t shuffle_max = 10000;
38 
39   // Adjust the num rows per shard if sharding was given
40   if (num_devices > 0) {
41     if (num_rows % num_devices == 0) {
42       num_rows = num_rows / num_devices;
43     } else {
44       num_rows = (num_rows / num_devices) + 1;
45     }
46   }
47 
48   // Cap based on total rows directive.  Some ops do not have this and give value of 0.
49   if (total_rows > 0) {
50     num_rows = std::min(num_rows, total_rows);
51   }
52 
53   // get the average per file
54   CHECK_FAIL_RETURN_UNEXPECTED(num_files != 0, "The size of dataset_files must be greater than 0.");
55   int64_t avg_rows_per_file = num_rows / num_files;
56 
57   if (avg_rows_per_file != 0) {
58     *shuffle_size = std::min(avg_rows_per_file * average_files_multiplier, shuffle_max);
59   } else {
60     *shuffle_size = shuffle_max;
61   }
62 
63   return Status::OK();
64 }
65 
66 // Helper function to inject a shuffle operator over top of current operator being built
AddShuffleOp(int64_t num_files,int64_t num_devices,int64_t num_rows,int64_t total_rows,int32_t connector_que_size,std::shared_ptr<ShuffleOp> * shuffle_op)67 Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
68                     int32_t connector_que_size, std::shared_ptr<ShuffleOp> *shuffle_op) {
69   RETURN_UNEXPECTED_IF_NULL(shuffle_op);
70   int64_t shuffle_size = 0;
71   RETURN_IF_NOT_OK(ComputeShuffleSize(num_files, num_devices, num_rows, total_rows, &shuffle_size));
72   MS_LOG(INFO) << "Dataset::AddShuffleOp - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size;
73   // Add the shuffle op
74   *shuffle_op = std::make_shared<ShuffleOp>(shuffle_size, GetSeed(), connector_que_size, true);
75   return Status::OK();
76 }
77 
78 // Helper function to validate dataset directory parameter
ValidateDatasetDirParam(const std::string & dataset_name,std::string dataset_dir)79 Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) {
80   if (dataset_dir.empty()) {
81     std::string err_msg = dataset_name + ": dataset_dir is not specified.";
82     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
83   }
84 
85   std::string real_path;
86   RETURN_IF_NOT_OK(Path::RealPath(dataset_dir, real_path));
87   Path dir(dataset_dir);
88   if (!dir.IsDirectory()) {
89     std::string err_msg = dataset_name + ": dataset_dir: [" + dataset_dir + "] is an invalid directory path.";
90     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
91   }
92 
93   if (access(dataset_dir.c_str(), R_OK) == -1) {
94     std::string err_msg = dataset_name + ": No access to specified dataset path: " + dataset_dir;
95     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
96   }
97 
98   return Status::OK();
99 }
100 
101 // Helper function to validate dataset files parameter
ValidateDatasetFilesParam(const std::string & dataset_name,const std::vector<std::string> & dataset_files,const std::string & file_name)102 Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files,
103                                  const std::string &file_name) {
104   if (dataset_files.empty()) {
105     std::string err_msg = dataset_name + ": dataset files list is empty, check input dataset_files.";
106     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
107   }
108 
109   for (const auto &f : dataset_files) {
110     Path dataset_file(f);
111     if (!dataset_file.Exists()) {
112       std::string err_msg = dataset_name + ": " + file_name + ": [" + f + "] is invalid or does not exist.";
113       LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
114     }
115     if (access(dataset_file.ToString().c_str(), R_OK) == -1) {
116       std::string err_msg = dataset_name + ": No access to specified " + file_name + ": " + f;
117       LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
118     }
119   }
120 
121   return Status::OK();
122 }
123 
124 // Helper function to validate dataset num_shards and shard_id parameters
ValidateDatasetShardParams(const std::string & dataset_name,int32_t num_shards,int32_t shard_id)125 Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) {
126   if (num_shards <= 0) {
127     std::string err_msg = dataset_name + ": Invalid num_shards: " + std::to_string(num_shards);
128     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
129   }
130 
131   if (shard_id < 0 || shard_id >= num_shards) {
132     // num_shards
133     std::string err_msg = dataset_name + ": Invalid input, shard_id: " + std::to_string(shard_id) +
134                           ", num_shards: " + std::to_string(num_shards);
135     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
136   }
137 
138   return Status::OK();
139 }
140 
141 // Helper function to validate dataset sampler parameter
ValidateDatasetSampler(const std::string & dataset_name,const std::shared_ptr<SamplerObj> & sampler)142 Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler) {
143   if (sampler == nullptr) {
144     std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr";
145     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
146   }
147   RETURN_IF_NOT_OK(sampler->ValidateParams());
148   return Status::OK();
149 }
150 
ValidateStringValue(const std::string & dataset_name,const std::string & str,const std::unordered_set<std::string> & valid_strings)151 Status ValidateStringValue(const std::string &dataset_name, const std::string &str,
152                            const std::unordered_set<std::string> &valid_strings) {
153   if (valid_strings.find(str) == valid_strings.end()) {
154     std::string init;
155     std::string mode = std::accumulate(valid_strings.begin(), valid_strings.end(), init,
156                                        [](std::string a, std::string b) { return std::move(a) + " " + std::move(b); });
157     std::string err_msg = dataset_name + ": " + str + " does not match any mode in [" + mode + " ]";
158     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
159   }
160   return Status::OK();
161 }
162 
163 // Helper function to validate dataset input/output column parameter
ValidateDatasetColumnParam(const std::string & dataset_name,const std::string & column_param,const std::vector<std::string> & columns)164 Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param,
165                                   const std::vector<std::string> &columns) {
166   if (columns.empty()) {
167     std::string err_msg = dataset_name + ": '" + column_param + "' should not be empty string";
168     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
169   }
170   for (uint32_t i = 0; i < columns.size(); ++i) {
171     if (columns[i].empty()) {
172       std::string err_msg = dataset_name + ": '" + column_param + "' [" + std::to_string(i) + "] must not be empty";
173       LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
174     }
175   }
176   std::set<std::string> columns_set;
177   for (auto &column_name : columns) {
178     auto result = columns_set.insert(column_name);
179     if (result.second == false) {
180       std::string err_msg = dataset_name + ": '" + column_param +
181                             "' : Invalid parameter, duplicate column names are not allowed: " + *result.first;
182       LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
183     }
184   }
185   return Status::OK();
186 }
187 
ValidateMapKey(const std::string & dataset_name,const std::string & key,const std::map<std::string,std::vector<std::string>> & map)188 Status ValidateMapKey(const std::string &dataset_name, const std::string &key,
189                       const std::map<std::string, std::vector<std::string>> &map) {
190   if (map.find(key) == map.end()) {
191     std::string init;
192     std::string mode = std::accumulate(map.begin(), map.end(), init,
193                                        [](std::string a, auto b) { return std::move(a) + " " + std::move(b.first); });
194     std::string err_msg = dataset_name + ": " + key + " does not match any key in [" + mode + " ]";
195     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
196   }
197   return Status::OK();
198 }
199 
ValidateMapValue(const std::string & dataset_name,const std::string & str,const std::vector<std::string> & valid_strings)200 Status ValidateMapValue(const std::string &dataset_name, const std::string &str,
201                         const std::vector<std::string> &valid_strings) {
202   if (find(valid_strings.begin(), valid_strings.end(), str) == valid_strings.end()) {
203     std::string init;
204     std::string mode = std::accumulate(valid_strings.begin(), valid_strings.end(), init,
205                                        [](std::string a, std::string b) { return std::move(a) + " " + std::move(b); });
206     std::string err_msg = dataset_name + ": " + str + " does not match any string in [" + mode + " ]";
207     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
208   }
209   return Status::OK();
210 }
211 
SelectSampler(int64_t num_samples,bool shuffle,int32_t num_shards,int32_t shard_id)212 std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id) {
213   if (shuffle) {
214     if (num_shards > 1) {
215       // If shuffle enabled, sharding enabled, use distributed random sampler
216       return DistributedSampler(num_shards, shard_id, shuffle, num_samples).Parse();
217     }
218     // If shuffle enabled, sharding disabled, use random sampler
219     return RandomSampler(num_samples >= 0, num_samples).Parse();
220   }
221   if (num_shards > 1) {
222     // If shuffle disabled, sharding enabled, use distributed sequential sampler
223     return DistributedSampler(num_shards, shard_id, shuffle, num_samples).Parse();
224   }
225   // If shuffle disabled, sharding disabled, use sequential sampler
226   return SequentialSampler(0, num_samples).Parse();
227 }
228 
229 // Constructor to initialize the cache
DatasetNode(const std::shared_ptr<DatasetCache> & dataset_cache)230 DatasetNode::DatasetNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode() { cache_ = dataset_cache; }
231 
SetNumWorkers(int32_t num_workers)232 std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) {
233   num_workers_ = num_workers;
234   return shared_from_this();
235 }
236 
SetConnectorQueueSize(int32_t connector_queue_size)237 std::shared_ptr<DatasetNode> DatasetNode::SetConnectorQueueSize(int32_t connector_queue_size) {
238   connector_que_size_ = connector_queue_size;
239   return shared_from_this();
240 }
241 
SetDatasetCache(const std::shared_ptr<DatasetCache> & cache)242 std::shared_ptr<DatasetNode> DatasetNode::SetDatasetCache(const std::shared_ptr<DatasetCache> &cache) {
243   cache_ = cache;
244   return shared_from_this();
245 }
246 
DatasetNode()247 DatasetNode::DatasetNode()
248     : cache_(nullptr),
249       parent_(nullptr),
250       children_({}),
251       dataset_size_(-1),
252       mappable_(kNotADataSource),
253       nary_op_(false),
254       descendant_of_cache_(false),
255       total_repeats_(-1),
256       num_epochs_(1) {
257   // Fetch some default value from config manager
258   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
259   num_workers_ = cfg->num_parallel_workers();
260   connector_que_size_ = cfg->op_connector_size();
261   worker_connector_size_ = cfg->worker_connector_size();
262 }
263 
PrintColumns(const std::vector<std::string> & columns) const264 std::string DatasetNode::PrintColumns(const std::vector<std::string> &columns) const {
265   std::string me;
266   if (columns.empty()) {
267     me = "<nil>";
268   } else {
269     me = "[";
270     auto i = 0;
271     for (auto it = columns.begin(); it < columns.end(); ++it, ++i) {
272       me += *it;
273       if (i < columns.size() - 1) {
274         me += ", ";
275       } else {
276         me += "]";
277       }
278     }
279   }
280   return me;
281 }
282 
PrintTree(std::ostream & out) const283 void DatasetNode::PrintTree(std::ostream &out) const {
284   int level = 0;
285   PrintNode(out, &level);
286 }
287 
PrintNode(std::ostream & out,int * level) const288 void DatasetNode::PrintNode(std::ostream &out, int *level) const {
289   const std::string prefix = "+-";
290   const std::string indent = "| ";
291   out << prefix;
292   Print(out);
293   for (const auto &c : this->Children()) {
294     out << '\n';
295     ++(*level);
296     for (auto i = 0; i < *level; i++) {
297       out << indent;
298     }
299     c->PrintNode(out, level);
300     --(*level);
301   }
302 }
303 
304 // Add a node as a child, node's parent needs to be empty
305 // This function will allow child to be a nullptr, in which case it will simply skip.
306 // This function is used only when building IR node one by one from parsing the user code.
307 // During the parsing, we allow a node to have more than one parent, possibly forming a graph.
308 // It does not maintain the parent_ attribute of the node, which enforces a single parent and a tree structure.
AddChild(std::shared_ptr<DatasetNode> child)309 void DatasetNode::AddChild(std::shared_ptr<DatasetNode> child) {
310   if (child != nullptr) {
311     children_.push_back(child);
312   }
313 }
314 
315 /*
316  * AppendChild(<node>) appending <node> as the last child of this node. The new node must have no parent.
317  *
318  * Input tree:
319  *      ds4
320  *     /   \
321  *   ds3   ds2
322  *     |
323  *    ds1
324  *
325  * ds4->AppendChild(ds6) yields this tree
326  *
327  *      _ ds4 _
328  *     /   |   \
329  *   ds3  ds2  ds6
330  *    |
331  *   ds1
332  *
333  */
AppendChild(std::shared_ptr<DatasetNode> child)334 Status DatasetNode::AppendChild(std::shared_ptr<DatasetNode> child) {
335   CHECK_FAIL_RETURN_UNEXPECTED(IsOrphanNode(child), "Node to append must be an orphan node.");
336   CHECK_FAIL_RETURN_UNEXPECTED((IsUnaryOperator() && Children().empty()) || IsNaryOperator(),
337                                "This node must be a unary operator with no child or an n-ary operator");
338   children_.push_back(child);
339   child->parent_ = this;
340   return Status::OK();
341 }
342 
343 /*
344  * InsertChildAt(<pos>, <node>) inserts the <node> to be at the <pos> index of the vector of its child nodes.
345  * As in the convention of C++, <pos> starts at position 0.
346  * If the <pos> is a negative number or larger than the size of the vector minus one, an error is raised.
347  */
InsertChildAt(int32_t pos,std::shared_ptr<DatasetNode> child)348 Status DatasetNode::InsertChildAt(int32_t pos, std::shared_ptr<DatasetNode> child) {
349   CHECK_FAIL_RETURN_UNEXPECTED(pos > -1 && pos <= children_.size(), "Position must in the range of [0, size]");
350   CHECK_FAIL_RETURN_UNEXPECTED(IsOrphanNode(child), "Node to append must be an orphan node.");
351   CHECK_FAIL_RETURN_UNEXPECTED((IsUnaryOperator() && Children().empty()) || IsNaryOperator(),
352                                "This node must be a unary operator with no child or an n-ary operator");
353   children_.insert(children_.begin() + pos, child);
354   child->parent_ = this;
355   return Status::OK();
356 }
357 
358 /*
359  * Insert the input <node> above this node
360  * Input tree:
361  *       ds4
362  *      /   \
363  *     ds3  ds2
364  *      |
365  *     ds1
366  *
367  * Case 1: If we want to insert a new node ds5 between ds4 and ds3, use
368  *           ds3->InsertAbove(ds5)
369  *
370  *       ds4
371  *      /   \
372  *     ds5  ds2
373  *      |
374  *     ds3
375  *      |
376  *     ds1
377  *
378  * Case 2: Likewise, ds2->InsertAbove(ds6) yields
379  *
380  *       ds4
381  *      /   \
382  *     ds3  ds6
383  *      |    |
384  *     ds1  ds2
385  *
386  * Case 3: We can insert a new node between ds3 and ds1 by ds1->InsertAbove(ds7)
387  *
388  *       ds4
389  *      /   \
390  *     ds3  ds2
391  *      |
392  *     ds7
393  *      |
394  *     ds1
395  *
396  * InsertAbove() cannot use on the root node of a tree.
397  */
InsertAbove(std::shared_ptr<DatasetNode> node)398 Status DatasetNode::InsertAbove(std::shared_ptr<DatasetNode> node) {
399   CHECK_FAIL_RETURN_UNEXPECTED(IsOrphanNode(node), "Node to insert must be an orphan node.");
400   CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "This node must not be the root or a node without parent.");
401   auto parent = parent_;
402 
403   // The following fields of these three nodes are changed in this function:
404   // 1. parent->children_
405   // 2. node->parent_ and node->children_
406   // 3. this->parent_
407   auto current_node_itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this());
408   *current_node_itr = node;  // replace me in my parent's children list with the newly inserted node
409   node->parent_ = parent;    // set the newly inserted node's parent ptr to my parent
410   node->children_.push_back(shared_from_this());  // add myself to the newly inserted node's children list
411   parent_ = node.get();                           // set my parent ptr to the newly inserted node
412 
413   return Status::OK();
414 }
415 
416 /*
417  * Drop() detaches this node from the tree it is in. Calling Drop() from a standalone node is a no-op.
418  *
419  * Input tree:
420  *       ds10
421  *      /    \
422  *    ds9    ds6
423  *     |   /  |  \
424  *    ds8 ds5 ds4 ds1
425  *     |     /  \
426  *    ds7  ds3  ds2
427  *
428  * Case 1: When the node has no child and no sibling, Drop() detaches the node from its tree.
429  *
430  *   ds7->Drop() yields the tree below:
431  *
432  *       ds10
433  *      /    \
434  *    ds9    ds6
435  *     |   /  |  \
436  *    ds8 ds5 ds4 ds1
437  *           /  \
438  *         ds3  ds2
439  *
440  * Case 2: When the node has one child and no sibling, Drop() detaches the node from its tree and the node's child
441  *         becomes its parent's child.
442  *
443  *   ds8->Drop() yields the tree below:
444  *
445  *       ds10
446  *      /    \
447  *    ds9    ds6
448  *     |   /  |  \
449  *    ds7 ds5 ds4 ds1
450  *           /  \
451  *         ds3  ds2
452  *
453  * Case 3: When the node has more than one child and no sibling, Drop() detaches the node from its tree and the node's
454  *         children become its parent's children.
455  *
456  *   When the input tree is
457  *
458  *       ds10
459  *      /    \
460  *    ds9    ds6
461  *     |      |
462  *    ds8    ds4
463  *     |    /   \
464  *    ds7  ds3  ds2
465  *
466  *    ds4->Drop() yields the tree below:
467  *
468  *       ds10
469  *      /    \
470  *    ds9    ds6
471  *     |     /  \
472  *    ds8  ds3  ds2
473  *     |
474  *    ds7
475  *
476  *   But if ds6 is not an n-ary operator, ds4->Drop() will raise an error because we cannot add the children of an
477  *   n-ary operator (ds4) to a unary operator (ds6).
478  *
479  * Case 4: When the node has no child but has siblings, Drop() detaches the node from its tree and its siblings will be
480  *         squeezed left.
481  *
482  * Input tree:
483  *       ds10
484  *      /    \
485  *    ds9    ds6
486  *     |   /  |  \
487  *    ds8 ds5 ds4 ds1
488  *     |     /  \
489  *    ds7  ds3  ds2
490  *
491  *   ds5->Drop() yields the tree below:
492  *
493  *       ds10
494  *      /    \
495  *    ds9    ds6
496  *     |     /  \
497  *    ds8   ds4 ds1
498  *     |    /  \
499  *    ds7 ds3  ds2
500  *
501  * Case 5: When the node has only one child but has siblings, Drop() detaches the node from its tree and the node's
502  *         children become its parent's children.
503  *
504  * Input tree:
505  *       ds10
506  *      /    \
507  *    ds9    ds6
508  *     |   /  |  \
509  *    ds8 ds5 ds4 ds1
510  *     |      |
511  *    ds7     ds3
512  *
513  *   ds4->Drop() yields the tree below:
514  *
515  *       ds10
516  *      /    \
517  *    ds9    ds6
518  *     |   /  |  \
519  *    ds8 ds5 ds3 ds1
520  *     |
521  *    ds7
522  *
523  * Case 6: When the node has more than one child and more than one sibling, Drop() will raise an error.
524  *         If we want to drop ds4 from the input tree, ds4->Drop() will not work. We will have to do it
525  *         with a combination of Drop(), InsertChildAt()
526  *
527  * Input tree:
528  *       ds10
529  *      /    \
530  *    ds9    ds6
531  *     |   /  |  \
532  *    ds8 ds5 ds4 ds1
533  *     |     /  \
534  *    ds7  ds3  ds2
535  *
536  * If we want to form this tree below:
537  *
538  *       ds10
539  *      /    \
540  *    ds9    ds6_____
541  *     |   /  |   |  \
542  *    ds8 ds5 ds3 ds2 ds1
543  *     |
544  *    ds7
545  *
546  */
Drop()547 Status DatasetNode::Drop() {
548   CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "This node to drop must not be the root or a node without parent.");
549   CHECK_FAIL_RETURN_UNEXPECTED(!(IsNaryOperator() && parent_->IsUnaryOperator()),
550                                "Trying to drop an n-ary operator that is a child of a unary operator");
551   CHECK_FAIL_RETURN_UNEXPECTED(!(children_.size() > 1 && parent_->children_.size() > 1),
552                                "This node to drop must not have more than one child and more than one sibling.");
553   if (parent_->children_.size() == 1) {
554     auto my_parent = parent_;
555     // Case 2: When the node has one child and no sibling, Drop() detaches the node from its tree and the node's child
556     //         becomes its parent's child.
557     // This is the most common use case.
558     if (children_.size() == 1) {
559       auto child = children_[0];
560       // Move its child to be its parent's child
561       my_parent->children_[0] = child;
562       child->parent_ = my_parent;
563     } else if (children_.empty()) {
564       // Case 1: When the node has no child and no sibling, Drop() detaches the node from its tree.
565       // Remove this node from its parent's child
566       parent_->children_.clear();
567     } else if (children_.size() > 1) {
568       // Case 3: When the node has more than one child and no sibling, Drop() detaches the node from its tree and
569       //         the node's children become its parent's children.
570       // Remove this node from its parent's child
571       my_parent->children_.clear();
572       // Move its child to be its parent's child
573       for (auto &my_child : children_) {
574         my_parent->children_.push_back(my_child);
575         my_child->parent_ = my_parent;
576       }
577     }
578     // And mark itself as an orphan
579     parent_ = nullptr;
580     children_.clear();
581   } else if (children_.empty() && parent_->children_.size() > 1) {
582     // Case 4: When the node has no child but has siblings, Drop() detaches the node from its tree and its siblings will
583     //         be squeezed left.
584     auto parent = parent_;
585     // Remove this node from its parent's child
586     parent->children_.erase(std::remove(parent->children_.begin(), parent->children_.end(), shared_from_this()),
587                             parent->children_.end());  // removal using "erase remove idiom"
588     // And mark itself as an orphan
589     parent_ = nullptr;
590     children_.clear();
591   } else if (children_.size() == 1 && parent_->children_.size() > 1) {
592     // Case 5: When the node has only one child but has siblings, Drop() detaches the node from its tree and the node's
593     //         children become its parent's children.
594     auto itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this());
595     CHECK_FAIL_RETURN_UNEXPECTED(itr != parent_->children_.end(), "I am not in my parent's children list.");
596     *itr = children_[0];              // replace this node in its parent's children list with its single child
597     children_[0]->parent_ = parent_;  // set its single child's parent ptr to its parent
598     // And mark itself as an orphan
599     parent_ = nullptr;
600     children_.clear();
601   } else {
602     RETURN_STATUS_UNEXPECTED("Internal error: we should not reach here.");
603   }
604   return Status::OK();
605 }
606 
607 // In DFS tree traversal, each node is visited twice. Accept is called on the first visit.
Accept(IRNodePass * const p,bool * const modified)608 Status DatasetNode::Accept(IRNodePass *const p, bool *const modified) {
609   // This method will only be called if its derived class does not implement one.
610   return p->Visit(shared_from_this(), modified);
611 }
612 
613 // In DFS tree traversal, each node is visited twice. AcceptAfter is called on the second visit
614 // after all child nodes are visited.
AcceptAfter(IRNodePass * const p,bool * const modified)615 Status DatasetNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
616   // This method will only be called if its derived class does not implement one.
617   return p->VisitAfter(shared_from_this(), modified);
618 }
619 
GetShardId(int32_t * const shard_id)620 Status DatasetNode::GetShardId(int32_t *const shard_id) {
621   if (children_.size() == 1) {
622     // Get shard id from the child node
623     return children_[0]->GetShardId(shard_id);
624   } else if (children_.size() > 1) {
625     // It is okay for dataset to have more than 1 child, GetShardId shouldn't fail in this case.
626     // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
627     // always be in front of the child_ structure, so we get the dataset size from the last child.
628     return children_.back()->GetShardId(shard_id);
629   } else {
630     LOG_AND_RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n");
631   }
632 }
633 
634 // Gets the dataset size
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)635 Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
636                                    int64_t *dataset_size) {
637   if (dataset_size_ > 0) {
638     *dataset_size = dataset_size_;
639     return Status::OK();
640   }
641   if (!IsSizeDefined() && size_getter != nullptr) {
642     RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), dataset_size));
643     dataset_size_ = *dataset_size;
644     return Status::OK();
645   }
646   if (children_.size() == 1) {
647     return children_.front()->GetDatasetSize(size_getter, estimate, dataset_size);
648   } else if (children_.size() > 1) {
649     // It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case.
650     // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
651     // always be in front of the child_ structure, so we get the dataset size from the last child.
652     return children_.back()->GetDatasetSize(size_getter, estimate, dataset_size);
653   } else {
654     RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
655   }
656 }
ValidateParams()657 Status DatasetNode::ValidateParams() {
658   int32_t num_threads = GlobalContext::config_manager()->num_cpu_threads();
659   // in case std::thread::hardware_concurrency returns 0, use an artificial upper limit
660   num_threads = num_threads > 0 ? num_threads : std::numeric_limits<uint16_t>::max();
661   CHECK_FAIL_RETURN_UNEXPECTED(
662     num_workers_ > 0 && num_workers_ <= num_threads,
663     Name() + "'s num_workers=" + std::to_string(num_workers_) +
664       ", this value is not within the required range of [1, cpu_thread_cnt=" + std::to_string(num_threads) + "].");
665   return Status::OK();
666 }
667 
to_json(nlohmann::json * out_json)668 Status DatasetNode::to_json(nlohmann::json *out_json) {
669   nlohmann::json args;
670   args["num_parallel_workers"] = num_workers_;
671   args["connector_queue_size"] = connector_que_size_;
672   *out_json = args;
673   return Status::OK();
674 }
675 
Accept(IRNodePass * const p,bool * const modified)676 Status MappableSourceNode::Accept(IRNodePass *const p, bool *const modified) {
677   return p->Visit(shared_from_base<MappableSourceNode>(), modified);
678 }
679 
Accept(IRNodePass * const p,bool * const modified)680 Status NonMappableSourceNode::Accept(IRNodePass *const p, bool *const modified) {
681   return p->Visit(shared_from_base<NonMappableSourceNode>(), modified);
682 }
683 
684 }  // namespace dataset
685 }  // namespace mindspore
686