• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
18 
19 #include <algorithm>
20 #include <limits>
21 #include <memory>
22 #include <set>
23 
24 #include "minddata/dataset/engine/opt/pass.h"
25 #include "minddata/dataset/util/random.h"
26 #include "minddata/dataset/util/status.h"
27 
28 namespace mindspore {
29 namespace dataset {
30 
31 // Helper function to compute a default shuffle size
ComputeShuffleSize(int64_t num_files,int64_t num_devices,int64_t num_rows,int64_t total_rows,int64_t * shuffle_size)32 Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
33                           int64_t *shuffle_size) {
34   RETURN_UNEXPECTED_IF_NULL(shuffle_size);
35   const int64_t average_files_multiplier = 4;
36   const int64_t shuffle_max = 10000;
37   int64_t avg_rows_per_file = 0;
38 
39   // Adjust the num rows per shard if sharding was given
40   if (num_devices > 0) {
41     if (num_rows % num_devices == 0) {
42       num_rows = num_rows / num_devices;
43     } else {
44       num_rows = (num_rows / num_devices) + 1;
45     }
46   }
47 
48   // Cap based on total rows directive.  Some ops do not have this and give value of 0.
49   if (total_rows > 0) {
50     num_rows = std::min(num_rows, total_rows);
51   }
52 
53   // get the average per file
54   CHECK_FAIL_RETURN_UNEXPECTED(num_files != 0, "The size of dataset_files must greater than 0.");
55   avg_rows_per_file = num_rows / num_files;
56 
57   *shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max);
58   return Status::OK();
59 }
60 
61 // Helper function to inject a shuffle operator over top of current operator being built
AddShuffleOp(int64_t num_files,int64_t num_devices,int64_t num_rows,int64_t total_rows,int32_t connector_que_size,std::shared_ptr<DatasetOp> * shuffle_op)62 Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
63                     int32_t connector_que_size, std::shared_ptr<DatasetOp> *shuffle_op) {
64   RETURN_UNEXPECTED_IF_NULL(shuffle_op);
65   int64_t shuffle_size = 0;
66   RETURN_IF_NOT_OK(ComputeShuffleSize(num_files, num_devices, num_rows, total_rows, &shuffle_size));
67   MS_LOG(INFO) << "Dataset::AddShuffleOp - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size;
68   // Add the shuffle op
69   *shuffle_op = std::make_shared<ShuffleOp>(shuffle_size, GetSeed(), connector_que_size, true);
70   return Status::OK();
71 }
72 
73 // Helper function to validate dataset directory parameter
ValidateDatasetDirParam(const std::string & dataset_name,std::string dataset_dir)74 Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) {
75   if (dataset_dir.empty()) {
76     std::string err_msg = dataset_name + ": dataset_dir is not specified.";
77     MS_LOG(ERROR) << err_msg;
78     RETURN_STATUS_SYNTAX_ERROR(err_msg);
79   }
80 
81   std::string real_path;
82   RETURN_IF_NOT_OK(Path::RealPath(dataset_dir, real_path));
83   Path dir(dataset_dir);
84   if (!dir.IsDirectory()) {
85     std::string err_msg = dataset_name + ": dataset_dir: [" + dataset_dir + "] is an invalid directory path.";
86     MS_LOG(ERROR) << err_msg;
87     RETURN_STATUS_SYNTAX_ERROR(err_msg);
88   }
89 
90   if (access(dataset_dir.c_str(), R_OK) == -1) {
91     std::string err_msg = dataset_name + ": No access to specified dataset path: " + dataset_dir;
92     MS_LOG(ERROR) << err_msg;
93     RETURN_STATUS_SYNTAX_ERROR(err_msg);
94   }
95 
96   return Status::OK();
97 }
98 
99 // Helper function to validate dataset files parameter
ValidateDatasetFilesParam(const std::string & dataset_name,const std::vector<std::string> & dataset_files)100 Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files) {
101   if (dataset_files.empty()) {
102     std::string err_msg = dataset_name + ": dataset_files is not specified.";
103     MS_LOG(ERROR) << err_msg;
104     RETURN_STATUS_SYNTAX_ERROR(err_msg);
105   }
106 
107   for (auto f : dataset_files) {
108     Path dataset_file(f);
109     if (!dataset_file.Exists()) {
110       std::string err_msg = dataset_name + ": dataset file: [" + f + "] is invalid or does not exist.";
111       MS_LOG(ERROR) << err_msg;
112 
113       RETURN_STATUS_SYNTAX_ERROR(err_msg);
114     }
115     if (access(dataset_file.ToString().c_str(), R_OK) == -1) {
116       std::string err_msg = dataset_name + ": No access to specified dataset file: " + f;
117       MS_LOG(ERROR) << err_msg;
118       RETURN_STATUS_SYNTAX_ERROR(err_msg);
119     }
120   }
121 
122   return Status::OK();
123 }
124 
125 // Helper function to validate dataset num_shards and shard_id parameters
ValidateDatasetShardParams(const std::string & dataset_name,int32_t num_shards,int32_t shard_id)126 Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) {
127   if (num_shards <= 0) {
128     std::string err_msg = dataset_name + ": Invalid num_shards: " + std::to_string(num_shards);
129     MS_LOG(ERROR) << err_msg;
130     RETURN_STATUS_SYNTAX_ERROR(err_msg);
131   }
132 
133   if (shard_id < 0 || shard_id >= num_shards) {
134     // num_shards
135     std::string err_msg = dataset_name + ": Invalid input, shard_id: " + std::to_string(shard_id) +
136                           ", num_shards: " + std::to_string(num_shards);
137     MS_LOG(ERROR) << err_msg;
138     RETURN_STATUS_SYNTAX_ERROR(err_msg);
139   }
140 
141   return Status::OK();
142 }
143 
144 // Helper function to validate dataset sampler parameter
ValidateDatasetSampler(const std::string & dataset_name,const std::shared_ptr<SamplerObj> & sampler)145 Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler) {
146   if (sampler == nullptr) {
147     std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr";
148     MS_LOG(ERROR) << err_msg;
149     RETURN_STATUS_SYNTAX_ERROR(err_msg);
150   }
151   RETURN_IF_NOT_OK(sampler->ValidateParams());
152   return Status::OK();
153 }
154 
ValidateStringValue(const std::string & dataset_name,const std::string & str,const std::unordered_set<std::string> & valid_strings)155 Status ValidateStringValue(const std::string &dataset_name, const std::string &str,
156                            const std::unordered_set<std::string> &valid_strings) {
157   if (valid_strings.find(str) == valid_strings.end()) {
158     std::string init;
159     std::string mode = std::accumulate(valid_strings.begin(), valid_strings.end(), init,
160                                        [](std::string a, std::string b) { return std::move(a) + " " + std::move(b); });
161     std::string err_msg = dataset_name + ": " + str + " does not match any mode in [" + mode + " ]";
162     MS_LOG(ERROR) << err_msg;
163     RETURN_STATUS_SYNTAX_ERROR(err_msg);
164   }
165   return Status::OK();
166 }
167 
168 // Helper function to validate dataset input/output column parameter
ValidateDatasetColumnParam(const std::string & dataset_name,const std::string & column_param,const std::vector<std::string> & columns)169 Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param,
170                                   const std::vector<std::string> &columns) {
171   if (columns.empty()) {
172     std::string err_msg = dataset_name + ":" + column_param + " should not be empty string";
173     MS_LOG(ERROR) << err_msg;
174     RETURN_STATUS_SYNTAX_ERROR(err_msg);
175   }
176   for (uint32_t i = 0; i < columns.size(); ++i) {
177     if (columns[i].empty()) {
178       std::string err_msg = dataset_name + ":" + column_param + "[" + std::to_string(i) + "] must not be empty";
179       MS_LOG(ERROR) << err_msg;
180       RETURN_STATUS_SYNTAX_ERROR(err_msg);
181     }
182   }
183   std::set<std::string> columns_set;
184   for (auto &column_name : columns) {
185     auto result = columns_set.insert(column_name);
186     if (result.second == false) {
187       std::string err_msg = dataset_name + ":" + column_param +
188                             ": Invalid parameter, duplicate column names are not allowed: " + *result.first;
189       MS_LOG(ERROR) << err_msg;
190       RETURN_STATUS_SYNTAX_ERROR(err_msg);
191     }
192   }
193   return Status::OK();
194 }
195 
SelectSampler(int64_t num_samples,bool shuffle,int32_t num_shards,int32_t shard_id)196 std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id) {
197   if (shuffle) {
198     if (num_shards > 1) {
199       // If shuffle enabled, sharding enabled, use distributed random sampler
200       return DistributedSampler(num_shards, shard_id, shuffle, num_samples).Parse();
201     }
202     // If shuffle enabled, sharding disabled, use random sampler
203     return RandomSampler(num_samples >= 0, num_samples).Parse();
204   }
205   if (num_shards > 1) {
206     // If shuffle disabled, sharding enabled, use distributed sequential sampler
207     return DistributedSampler(num_shards, shard_id, shuffle, num_samples).Parse();
208   }
209   // If shuffle disabled, sharding disabled, use sequential sampler
210   return SequentialSampler(0, num_samples).Parse();
211 }
212 
213 // Constructor to initialize the cache
DatasetNode(const std::shared_ptr<DatasetCache> & dataset_cache)214 DatasetNode::DatasetNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode() { cache_ = dataset_cache; }
215 
SetNumWorkers(int32_t num_workers)216 std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) {
217   num_workers_ = num_workers;
218   return shared_from_this();
219 }
220 
SetDatasetCache(const std::shared_ptr<DatasetCache> & cache)221 std::shared_ptr<DatasetNode> DatasetNode::SetDatasetCache(const std::shared_ptr<DatasetCache> &cache) {
222   cache_ = cache;
223   return shared_from_this();
224 }
225 
// Default constructor: initializes every field to a "not yet known / not yet set"
// sentinel, then pulls parallelism and queue-size defaults from the global ConfigManager.
DatasetNode::DatasetNode()
    : cache_(nullptr),
      parent_(nullptr),
      children_({}),
      dataset_size_(-1),  // -1: size has not been computed yet (see GetDatasetSize)
      mappable_(kNotADataSource),
      nary_op_(false),
      descendant_of_cache_(false),
      total_repeats_(-1),  // -1: not yet assigned
      num_epochs_(1) {
  // Fetch some default value from config manager
  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  num_workers_ = cfg->num_parallel_workers();
  connector_que_size_ = cfg->op_connector_size();
  worker_connector_size_ = cfg->worker_connector_size();
}
242 
PrintColumns(const std::vector<std::string> & columns) const243 std::string DatasetNode::PrintColumns(const std::vector<std::string> &columns) const {
244   std::string me;
245   if (columns.empty()) {
246     me = "<nil>";
247   } else {
248     me = "[";
249     auto i = 0;
250     for (auto it = columns.begin(); it < columns.end(); ++it, ++i) {
251       me += *it;
252       if (i < columns.size() - 1) {
253         me += ", ";
254       } else {
255         me += "]";
256       }
257     }
258   }
259   return me;
260 }
261 
PrintTree(std::ostream & out) const262 void DatasetNode::PrintTree(std::ostream &out) const {
263   int level = 0;
264   PrintNode(out, &level);
265 }
266 
// Recursively print this node and its subtree.  *level tracks the current depth and
// is shared across the recursion so each child is indented one step beyond its parent.
void DatasetNode::PrintNode(std::ostream &out, int *level) const {
  const std::string prefix = "+-";
  const std::string indent = "| ";
  out << prefix;
  Print(out);
  for (const auto &c : this->Children()) {
    out << '\n';
    ++(*level);
    // Emit one indent unit per ancestor level before the child's own prefix.
    for (auto i = 0; i < *level; i++) {
      out << indent;
    }
    c->PrintNode(out, level);
    --(*level);
  }
}
282 
283 // Add a node as a child, node's parent needs to be empty
284 // This function will allow child to be a nullptr, in which case it will simply skip.
285 // This function is used only when building IR node one by one from parsing the user code.
286 // During the parsing, we allow a node to have more than one parent, possibly forming a graph.
287 // It does not maintain the parent_ attribute of the node, which enforces a single parent and a tree structure.
AddChild(std::shared_ptr<DatasetNode> child)288 void DatasetNode::AddChild(std::shared_ptr<DatasetNode> child) {
289   if (child != nullptr) {
290     children_.push_back(child);
291   }
292 }
293 
294 /*
295  * AppendChild(<node>) appending <node> as the last child of this node. The new node must have no parent.
296  *
297  * Input tree:
298  *      ds4
299  *     /   \
300  *   ds3   ds2
301  *     |
302  *    ds1
303  *
304  * ds4->AppendChild(ds6) yields this tree
305  *
306  *      _ ds4 _
307  *     /   |   \
308  *   ds3  ds2  ds6
309  *    |
310  *   ds1
311  *
312  */
AppendChild(std::shared_ptr<DatasetNode> child)313 Status DatasetNode::AppendChild(std::shared_ptr<DatasetNode> child) {
314   CHECK_FAIL_RETURN_UNEXPECTED(IsOrphanNode(child), "Node to append must be an orphan node.");
315   CHECK_FAIL_RETURN_UNEXPECTED((IsUnaryOperator() && Children().empty()) || IsNaryOperator(),
316                                "This node must be a unary operator with no child or an n-ary operator");
317   children_.push_back(child);
318   child->parent_ = this;
319   return Status::OK();
320 }
321 
322 /*
323  * InsertChildAt(<pos>, <node>) inserts the <node> to be at the <pos> index of the vector of its child nodes.
324  * As in the convention of C++, <pos> starts at position 0.
 325  * If the <pos> is a negative number or larger than the current number of children, an error is raised
 326  * (inserting at exactly the current size appends the node as the last child).
326  */
InsertChildAt(int32_t pos,std::shared_ptr<DatasetNode> child)327 Status DatasetNode::InsertChildAt(int32_t pos, std::shared_ptr<DatasetNode> child) {
328   CHECK_FAIL_RETURN_UNEXPECTED(pos > -1 && pos <= children_.size(), "Position must in the range of [0, size]");
329   CHECK_FAIL_RETURN_UNEXPECTED(IsOrphanNode(child), "Node to append must be an orphan node.");
330   CHECK_FAIL_RETURN_UNEXPECTED((IsUnaryOperator() && Children().empty()) || IsNaryOperator(),
331                                "This node must be a unary operator with no child or an n-ary operator");
332   children_.insert(children_.begin() + pos, child);
333   child->parent_ = this;
334   return Status::OK();
335 }
336 
337 /*
338  * Insert the input <node> above this node
339  * Input tree:
340  *       ds4
341  *      /   \
342  *     ds3  ds2
343  *      |
344  *     ds1
345  *
346  * Case 1: If we want to insert a new node ds5 between ds4 and ds3, use
347  *           ds3->InsertAbove(ds5)
348  *
349  *       ds4
350  *      /   \
351  *     ds5  ds2
352  *      |
353  *     ds3
354  *      |
355  *     ds1
356  *
357  * Case 2: Likewise, ds2->InsertAbove(ds6) yields
358  *
359  *       ds4
360  *      /   \
361  *     ds3  ds6
362  *      |    |
363  *     ds1  ds2
364  *
365  * Case 3: We can insert a new node between ds3 and ds1 by ds1->InsertAbove(ds7)
366  *
367  *       ds4
368  *      /   \
369  *     ds3  ds2
370  *      |
371  *     ds7
372  *      |
373  *     ds1
374  *
375  * InsertAbove() cannot use on the root node of a tree.
376  */
// Splice `node` between this node and its parent (see the diagram above).  `node`
// must be an orphan and this node must not be the root.
Status DatasetNode::InsertAbove(std::shared_ptr<DatasetNode> node) {
  CHECK_FAIL_RETURN_UNEXPECTED(IsOrphanNode(node), "Node to insert must be an orphan node.");
  CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "This node must not be the root or a node without parent.");
  auto parent = parent_;

  // The following fields of these three nodes are changed in this function:
  // 1. parent->children_
  // 2. node->parent_ and node->children_
  // 3. this->parent_
  // NOTE(review): assumes this node is present in parent_->children_; the std::find
  // result is not checked before dereference — confirm that invariant holds for all callers.
  auto current_node_itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this());
  *current_node_itr = node;  // replace me in my parent's children list with the newly inserted node
  node->parent_ = parent;    // set the newly inserted node's parent ptr to my parent
  node->children_.push_back(shared_from_this());  // add myself to the newly inserted node's children list
  parent_ = node.get();                           // set my parent ptr to the newly inserted node

  return Status::OK();
}
394 
395 /*
396  * Drop() detaches this node from the tree it is in. Calling Drop() from a standalone node is a no-op.
397  *
398  * Input tree:
399  *       ds10
400  *      /    \
401  *    ds9    ds6
402  *     |   /  |  \
403  *    ds8 ds5 ds4 ds1
404  *     |     /  \
405  *    ds7  ds3  ds2
406  *
407  * Case 1: When the node has no child and no sibling, Drop() detaches the node from its tree.
408  *
409  *   ds7->Drop() yields the tree below:
410  *
411  *       ds10
412  *      /    \
413  *    ds9    ds6
414  *     |   /  |  \
415  *    ds8 ds5 ds4 ds1
416  *           /  \
417  *         ds3  ds2
418  *
419  * Case 2: When the node has one child and no sibling, Drop() detaches the node from its tree and the node's child
420  *         becomes its parent's child.
421  *
422  *   ds8->Drop() yields the tree below:
423  *
424  *       ds10
425  *      /    \
426  *    ds9    ds6
427  *     |   /  |  \
428  *    ds7 ds5 ds4 ds1
429  *           /  \
430  *         ds3  ds2
431  *
432  * Case 3: When the node has more than one child and no sibling, Drop() detaches the node from its tree and the node's
433  *         children become its parent's children.
434  *
435  *   When the input tree is
436  *
437  *       ds10
438  *      /    \
439  *    ds9    ds6
440  *     |      |
441  *    ds8    ds4
442  *     |    /   \
443  *    ds7  ds3  ds2
444  *
445  *    ds4->Drop() yields the tree below:
446  *
447  *       ds10
448  *      /    \
449  *    ds9    ds6
450  *     |     /  \
451  *    ds8  ds3  ds2
452  *     |
453  *    ds7
454  *
455  *   But if ds6 is not an n-ary operator, ds4->Drop() will raise an error because we cannot add the children of an
456  *   n-ary operator (ds4) to a unary operator (ds6).
457  *
458  * Case 4: When the node has no child but has siblings, Drop() detaches the node from its tree and its siblings will be
459  *         squeezed left.
460  *
461  * Input tree:
462  *       ds10
463  *      /    \
464  *    ds9    ds6
465  *     |   /  |  \
466  *    ds8 ds5 ds4 ds1
467  *     |     /  \
468  *    ds7  ds3  ds2
469  *
470  *   ds5->Drop() yields the tree below:
471  *
472  *       ds10
473  *      /    \
474  *    ds9    ds6
475  *     |     /  \
476  *    ds8   ds4 ds1
477  *     |    /  \
478  *    ds7 ds3  ds2
479  *
480  * Case 5: When the node has only one child but has siblings, Drop() detaches the node from its tree and the node's
481  *         children become its parent's children.
482  *
483  * Input tree:
484  *       ds10
485  *      /    \
486  *    ds9    ds6
487  *     |   /  |  \
488  *    ds8 ds5 ds4 ds1
489  *     |      |
490  *    ds7     ds3
491  *
492  *   ds4->Drop() yields the tree below:
493  *
494  *       ds10
495  *      /    \
496  *    ds9    ds6
497  *     |   /  |  \
498  *    ds8 ds5 ds3 ds1
499  *     |
500  *    ds7
501  *
502  * Case 6: When the node has more than one child and more than one sibling, Drop() will raise an error.
503  *         If we want to drop ds4 from the input tree, ds4->Drop() will not work. We will have to do it
504  *         with a combination of Drop(), InsertChildAt()
505  *
506  * Input tree:
507  *       ds10
508  *      /    \
509  *    ds9    ds6
510  *     |   /  |  \
511  *    ds8 ds5 ds4 ds1
512  *     |     /  \
513  *    ds7  ds3  ds2
514  *
515  * If we want to form this tree below:
516  *
517  *       ds10
518  *      /    \
519  *    ds9    ds6_____
520  *     |   /  |   |  \
521  *    ds8 ds5 ds3 ds2 ds1
522  *     |
523  *    ds7
524  *
525  */
// Detach this node from its tree, re-linking its children to its parent where the
// shapes allow it.  The six supported/unsupported cases are diagrammed in the
// comment block above.
Status DatasetNode::Drop() {
  CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "This node to drop must not be the root or a node without parent.");
  // A unary parent cannot adopt the multiple children of an n-ary node.
  CHECK_FAIL_RETURN_UNEXPECTED(!(IsNaryOperator() && parent_->IsUnaryOperator()),
                               "Trying to drop an n-ary operator that is a child of a unary operator");
  // Case 6 (multiple children AND multiple siblings) is explicitly unsupported.
  CHECK_FAIL_RETURN_UNEXPECTED(!(children_.size() > 1 && parent_->children_.size() > 1),
                               "This node to drop must not have more than one child and more than one sibling.");
  if (parent_->children_.size() == 1) {
    auto my_parent = parent_;
    // Case 2: When the node has one child and no sibling, Drop() detaches the node from its tree and the node's child
    //         becomes its parent's child.
    // This is the most common use case.
    if (children_.size() == 1) {
      auto child = children_[0];
      // Move its child to be its parent's child
      my_parent->children_[0] = child;
      child->parent_ = my_parent;
    } else if (children_.empty()) {
      // Case 1: When the node has no child and no sibling, Drop() detaches the node from its tree.
      // Remove this node from its parent's child
      parent_->children_.clear();
    } else if (children_.size() > 1) {
      // Case 3: When the node has more than one child and no sibling, Drop() detaches the node from its tree and
      //         the node's children become its parent's children.
      // Remove this node from its parent's child
      my_parent->children_.clear();
      // Move its child to be its parent's child
      for (auto &my_child : children_) {
        my_parent->children_.push_back(my_child);
        my_child->parent_ = my_parent;
      }
    }
    // And mark itself as an orphan
    parent_ = nullptr;
    children_.clear();
  } else if (children_.empty() && parent_->children_.size() > 1) {
    // Case 4: When the node has no child but has siblings, Drop() detaches the node from its tree and its siblings will
    //         be squeezed left.
    auto parent = parent_;
    // Remove this node from its parent's child
    parent->children_.erase(std::remove(parent->children_.begin(), parent->children_.end(), shared_from_this()),
                            parent->children_.end());  // removal using "erase remove idiom"
    // And mark itself as an orphan
    parent_ = nullptr;
    children_.clear();
  } else if (children_.size() == 1 && parent_->children_.size() > 1) {
    // Case 5: When the node has only one child but has siblings, Drop() detaches the node from its tree and the node's
    //         children become its parent's children.
    auto itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this());
    CHECK_FAIL_RETURN_UNEXPECTED(itr != parent_->children_.end(), "I am not in my parent's children list.");
    *itr = children_[0];              // replace this node in its parent's children list with its single child
    children_[0]->parent_ = parent_;  // set its single child's parent ptr to its parent
    // And mark itself as an orphan
    parent_ = nullptr;
    children_.clear();
  } else {
    // The guards above should have excluded every other shape.
    RETURN_STATUS_UNEXPECTED("Internal error: we should not reach here.");
  }
  return Status::OK();
}
585 
586 // In DFS tree traversal, each node is visited twice. Accept is called on the first visit.
Accept(IRNodePass * const p,bool * const modified)587 Status DatasetNode::Accept(IRNodePass *const p, bool *const modified) {
588   // This method will only be called if its derived class does not implement one.
589   return p->Visit(shared_from_this(), modified);
590 }
591 
592 // In DFS tree traversal, each node is visited twice. AcceptAfter is called on the second visit
593 // after all child nodes are visited.
AcceptAfter(IRNodePass * const p,bool * const modified)594 Status DatasetNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
595   // This method will only be called if its derived class does not implement one.
596   return p->VisitAfter(shared_from_this(), modified);
597 }
598 
GetShardId(int32_t * const shard_id)599 Status DatasetNode::GetShardId(int32_t *const shard_id) {
600   if (children_.size() == 1) {
601     // Get shard id from the child node
602     return children_[0]->GetShardId(shard_id);
603   } else if (children_.size() > 1) {
604     // It is okay for dataset to have more than 1 child, GetShardId shouldn't fail in this case.
605     // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
606     // always be in front of the child_ structure, so we get the dataset size from the last child.
607     return children_.back()->GetShardId(shard_id);
608   } else {
609     RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n");
610   }
611 }
612 
// Gets the dataset size.  The result is memoized in dataset_size_; ops without a
// statically-defined size are measured via a dry run of the tree.
Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                   int64_t *dataset_size) {
  // Fast path: return the cached value from an earlier computation.
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  if (!IsSizeDefined()) {
    // Size cannot be derived statically: execute a dry run and cache the result.
    RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), dataset_size));
    dataset_size_ = *dataset_size;
    return Status::OK();
  }
  if (children_.size() == 1) {
    return children_.front()->GetDatasetSize(size_getter, estimate, dataset_size);
  } else if (children_.size() > 1) {
    // It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case.
    // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
    // always be in front of the child_ structure, so we get the dataset size from the last child.
    return children_.back()->GetDatasetSize(size_getter, estimate, dataset_size);
  } else {
    // A leaf (no children) must override this method.
    RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
  }
}
ValidateParams()636 Status DatasetNode::ValidateParams() {
637   int32_t num_threads = GlobalContext::config_manager()->num_cpu_threads();
638   // in case std::thread::hardware_concurrency returns 0, use an artificial upper limit
639   num_threads = num_threads > 0 ? num_threads : std::numeric_limits<uint16_t>::max();
640   CHECK_FAIL_RETURN_UNEXPECTED(
641     num_workers_ > 0 && num_workers_ <= num_threads,
642     Name() + "'s num_workers=" + std::to_string(num_workers_) +
643       ", this value is not within the required range of [1, cpu_thread_cnt=" + std::to_string(num_threads) + "].");
644   return Status::OK();
645 }
646 
to_json(nlohmann::json * out_json)647 Status DatasetNode::to_json(nlohmann::json *out_json) {
648   nlohmann::json args;
649   args["num_parallel_workers"] = num_workers_;
650   *out_json = args;
651   return Status::OK();
652 }
653 
Accept(IRNodePass * const p,bool * const modified)654 Status MappableSourceNode::Accept(IRNodePass *const p, bool *const modified) {
655   return p->Visit(shared_from_base<MappableSourceNode>(), modified);
656 }
657 
Accept(IRNodePass * const p,bool * const modified)658 Status NonMappableSourceNode::Accept(IRNodePass *const p, bool *const modified) {
659   return p->Visit(shared_from_base<NonMappableSourceNode>(), modified);
660 }
661 
662 }  // namespace dataset
663 }  // namespace mindspore
664