1 /**
2 * Copyright 2020-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
18
19 #include <algorithm>
20 #include <limits>
21 #include <memory>
22 #include <set>
23 #include <utility>
24
25 #include "minddata/dataset/engine/opt/pass.h"
26 #include "minddata/dataset/util/random.h"
27 #include "minddata/dataset/util/status.h"
28
29 namespace mindspore {
30 namespace dataset {
31
32 // Helper function to compute a default shuffle size
ComputeShuffleSize(int64_t num_files,int64_t num_devices,int64_t num_rows,int64_t total_rows,int64_t * shuffle_size)33 Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
34 int64_t *shuffle_size) {
35 RETURN_UNEXPECTED_IF_NULL(shuffle_size);
36 const int64_t average_files_multiplier = 4;
37 const int64_t shuffle_max = 10000;
38
39 // Adjust the num rows per shard if sharding was given
40 if (num_devices > 0) {
41 if (num_rows % num_devices == 0) {
42 num_rows = num_rows / num_devices;
43 } else {
44 num_rows = (num_rows / num_devices) + 1;
45 }
46 }
47
48 // Cap based on total rows directive. Some ops do not have this and give value of 0.
49 if (total_rows > 0) {
50 num_rows = std::min(num_rows, total_rows);
51 }
52
53 // get the average per file
54 CHECK_FAIL_RETURN_UNEXPECTED(num_files != 0, "The size of dataset_files must be greater than 0.");
55 int64_t avg_rows_per_file = num_rows / num_files;
56
57 if (avg_rows_per_file != 0) {
58 *shuffle_size = std::min(avg_rows_per_file * average_files_multiplier, shuffle_max);
59 } else {
60 *shuffle_size = shuffle_max;
61 }
62
63 return Status::OK();
64 }
65
66 // Helper function to inject a shuffle operator over top of current operator being built
AddShuffleOp(int64_t num_files,int64_t num_devices,int64_t num_rows,int64_t total_rows,int32_t connector_que_size,std::shared_ptr<ShuffleOp> * shuffle_op)67 Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
68 int32_t connector_que_size, std::shared_ptr<ShuffleOp> *shuffle_op) {
69 RETURN_UNEXPECTED_IF_NULL(shuffle_op);
70 int64_t shuffle_size = 0;
71 RETURN_IF_NOT_OK(ComputeShuffleSize(num_files, num_devices, num_rows, total_rows, &shuffle_size));
72 MS_LOG(INFO) << "Dataset::AddShuffleOp - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size;
73 // Add the shuffle op
74 *shuffle_op = std::make_shared<ShuffleOp>(shuffle_size, GetSeed(), connector_que_size, true);
75 return Status::OK();
76 }
77
78 // Helper function to validate dataset directory parameter
ValidateDatasetDirParam(const std::string & dataset_name,std::string dataset_dir)79 Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) {
80 if (dataset_dir.empty()) {
81 std::string err_msg = dataset_name + ": dataset_dir is not specified.";
82 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
83 }
84
85 std::string real_path;
86 RETURN_IF_NOT_OK(Path::RealPath(dataset_dir, real_path));
87 Path dir(dataset_dir);
88 if (!dir.IsDirectory()) {
89 std::string err_msg = dataset_name + ": dataset_dir: [" + dataset_dir + "] is an invalid directory path.";
90 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
91 }
92
93 if (access(dataset_dir.c_str(), R_OK) == -1) {
94 std::string err_msg = dataset_name + ": No access to specified dataset path: " + dataset_dir;
95 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
96 }
97
98 return Status::OK();
99 }
100
101 // Helper function to validate dataset files parameter
ValidateDatasetFilesParam(const std::string & dataset_name,const std::vector<std::string> & dataset_files,const std::string & file_name)102 Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files,
103 const std::string &file_name) {
104 if (dataset_files.empty()) {
105 std::string err_msg = dataset_name + ": dataset files list is empty, check input dataset_files.";
106 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
107 }
108
109 for (const auto &f : dataset_files) {
110 Path dataset_file(f);
111 if (!dataset_file.Exists()) {
112 std::string err_msg = dataset_name + ": " + file_name + ": [" + f + "] is invalid or does not exist.";
113 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
114 }
115 if (access(dataset_file.ToString().c_str(), R_OK) == -1) {
116 std::string err_msg = dataset_name + ": No access to specified " + file_name + ": " + f;
117 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
118 }
119 }
120
121 return Status::OK();
122 }
123
124 // Helper function to validate dataset num_shards and shard_id parameters
ValidateDatasetShardParams(const std::string & dataset_name,int32_t num_shards,int32_t shard_id)125 Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) {
126 if (num_shards <= 0) {
127 std::string err_msg = dataset_name + ": Invalid num_shards: " + std::to_string(num_shards);
128 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
129 }
130
131 if (shard_id < 0 || shard_id >= num_shards) {
132 // num_shards
133 std::string err_msg = dataset_name + ": Invalid input, shard_id: " + std::to_string(shard_id) +
134 ", num_shards: " + std::to_string(num_shards);
135 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
136 }
137
138 return Status::OK();
139 }
140
141 // Helper function to validate dataset sampler parameter
ValidateDatasetSampler(const std::string & dataset_name,const std::shared_ptr<SamplerObj> & sampler)142 Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler) {
143 if (sampler == nullptr) {
144 std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr";
145 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
146 }
147 RETURN_IF_NOT_OK(sampler->ValidateParams());
148 return Status::OK();
149 }
150
ValidateStringValue(const std::string & dataset_name,const std::string & str,const std::unordered_set<std::string> & valid_strings)151 Status ValidateStringValue(const std::string &dataset_name, const std::string &str,
152 const std::unordered_set<std::string> &valid_strings) {
153 if (valid_strings.find(str) == valid_strings.end()) {
154 std::string init;
155 std::string mode = std::accumulate(valid_strings.begin(), valid_strings.end(), init,
156 [](std::string a, std::string b) { return std::move(a) + " " + std::move(b); });
157 std::string err_msg = dataset_name + ": " + str + " does not match any mode in [" + mode + " ]";
158 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
159 }
160 return Status::OK();
161 }
162
163 // Helper function to validate dataset input/output column parameter
ValidateDatasetColumnParam(const std::string & dataset_name,const std::string & column_param,const std::vector<std::string> & columns)164 Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param,
165 const std::vector<std::string> &columns) {
166 if (columns.empty()) {
167 std::string err_msg = dataset_name + ": '" + column_param + "' should not be empty string";
168 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
169 }
170 for (uint32_t i = 0; i < columns.size(); ++i) {
171 if (columns[i].empty()) {
172 std::string err_msg = dataset_name + ": '" + column_param + "' [" + std::to_string(i) + "] must not be empty";
173 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
174 }
175 }
176 std::set<std::string> columns_set;
177 for (auto &column_name : columns) {
178 auto result = columns_set.insert(column_name);
179 if (result.second == false) {
180 std::string err_msg = dataset_name + ": '" + column_param +
181 "' : Invalid parameter, duplicate column names are not allowed: " + *result.first;
182 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
183 }
184 }
185 return Status::OK();
186 }
187
ValidateMapKey(const std::string & dataset_name,const std::string & key,const std::map<std::string,std::vector<std::string>> & map)188 Status ValidateMapKey(const std::string &dataset_name, const std::string &key,
189 const std::map<std::string, std::vector<std::string>> &map) {
190 if (map.find(key) == map.end()) {
191 std::string init;
192 std::string mode = std::accumulate(map.begin(), map.end(), init,
193 [](std::string a, auto b) { return std::move(a) + " " + std::move(b.first); });
194 std::string err_msg = dataset_name + ": " + key + " does not match any key in [" + mode + " ]";
195 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
196 }
197 return Status::OK();
198 }
199
ValidateMapValue(const std::string & dataset_name,const std::string & str,const std::vector<std::string> & valid_strings)200 Status ValidateMapValue(const std::string &dataset_name, const std::string &str,
201 const std::vector<std::string> &valid_strings) {
202 if (find(valid_strings.begin(), valid_strings.end(), str) == valid_strings.end()) {
203 std::string init;
204 std::string mode = std::accumulate(valid_strings.begin(), valid_strings.end(), init,
205 [](std::string a, std::string b) { return std::move(a) + " " + std::move(b); });
206 std::string err_msg = dataset_name + ": " + str + " does not match any string in [" + mode + " ]";
207 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
208 }
209 return Status::OK();
210 }
211
SelectSampler(int64_t num_samples,bool shuffle,int32_t num_shards,int32_t shard_id)212 std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id) {
213 if (shuffle) {
214 if (num_shards > 1) {
215 // If shuffle enabled, sharding enabled, use distributed random sampler
216 return DistributedSampler(num_shards, shard_id, shuffle, num_samples).Parse();
217 }
218 // If shuffle enabled, sharding disabled, use random sampler
219 return RandomSampler(num_samples >= 0, num_samples).Parse();
220 }
221 if (num_shards > 1) {
222 // If shuffle disabled, sharding enabled, use distributed sequential sampler
223 return DistributedSampler(num_shards, shard_id, shuffle, num_samples).Parse();
224 }
225 // If shuffle disabled, sharding disabled, use sequential sampler
226 return SequentialSampler(0, num_samples).Parse();
227 }
228
229 // Constructor to initialize the cache
DatasetNode(const std::shared_ptr<DatasetCache> & dataset_cache)230 DatasetNode::DatasetNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode() { cache_ = dataset_cache; }
231
SetNumWorkers(int32_t num_workers)232 std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) {
233 num_workers_ = num_workers;
234 return shared_from_this();
235 }
236
SetConnectorQueueSize(int32_t connector_queue_size)237 std::shared_ptr<DatasetNode> DatasetNode::SetConnectorQueueSize(int32_t connector_queue_size) {
238 connector_que_size_ = connector_queue_size;
239 return shared_from_this();
240 }
241
SetDatasetCache(const std::shared_ptr<DatasetCache> & cache)242 std::shared_ptr<DatasetNode> DatasetNode::SetDatasetCache(const std::shared_ptr<DatasetCache> &cache) {
243 cache_ = cache;
244 return shared_from_this();
245 }
246
DatasetNode()247 DatasetNode::DatasetNode()
248 : cache_(nullptr),
249 parent_(nullptr),
250 children_({}),
251 dataset_size_(-1),
252 mappable_(kNotADataSource),
253 nary_op_(false),
254 descendant_of_cache_(false),
255 total_repeats_(-1),
256 num_epochs_(1) {
257 // Fetch some default value from config manager
258 std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
259 num_workers_ = cfg->num_parallel_workers();
260 connector_que_size_ = cfg->op_connector_size();
261 worker_connector_size_ = cfg->worker_connector_size();
262 }
263
PrintColumns(const std::vector<std::string> & columns) const264 std::string DatasetNode::PrintColumns(const std::vector<std::string> &columns) const {
265 std::string me;
266 if (columns.empty()) {
267 me = "<nil>";
268 } else {
269 me = "[";
270 auto i = 0;
271 for (auto it = columns.begin(); it < columns.end(); ++it, ++i) {
272 me += *it;
273 if (i < columns.size() - 1) {
274 me += ", ";
275 } else {
276 me += "]";
277 }
278 }
279 }
280 return me;
281 }
282
PrintTree(std::ostream & out) const283 void DatasetNode::PrintTree(std::ostream &out) const {
284 int level = 0;
285 PrintNode(out, &level);
286 }
287
PrintNode(std::ostream & out,int * level) const288 void DatasetNode::PrintNode(std::ostream &out, int *level) const {
289 const std::string prefix = "+-";
290 const std::string indent = "| ";
291 out << prefix;
292 Print(out);
293 for (const auto &c : this->Children()) {
294 out << '\n';
295 ++(*level);
296 for (auto i = 0; i < *level; i++) {
297 out << indent;
298 }
299 c->PrintNode(out, level);
300 --(*level);
301 }
302 }
303
304 // Add a node as a child, node's parent needs to be empty
305 // This function will allow child to be a nullptr, in which case it will simply skip.
306 // This function is used only when building IR node one by one from parsing the user code.
307 // During the parsing, we allow a node to have more than one parent, possibly forming a graph.
308 // It does not maintain the parent_ attribute of the node, which enforces a single parent and a tree structure.
AddChild(std::shared_ptr<DatasetNode> child)309 void DatasetNode::AddChild(std::shared_ptr<DatasetNode> child) {
310 if (child != nullptr) {
311 children_.push_back(child);
312 }
313 }
314
315 /*
316 * AppendChild(<node>) appending <node> as the last child of this node. The new node must have no parent.
317 *
318 * Input tree:
319 * ds4
320 * / \
321 * ds3 ds2
322 * |
323 * ds1
324 *
325 * ds4->AppendChild(ds6) yields this tree
326 *
327 * _ ds4 _
328 * / | \
329 * ds3 ds2 ds6
330 * |
331 * ds1
332 *
333 */
AppendChild(std::shared_ptr<DatasetNode> child)334 Status DatasetNode::AppendChild(std::shared_ptr<DatasetNode> child) {
335 CHECK_FAIL_RETURN_UNEXPECTED(IsOrphanNode(child), "Node to append must be an orphan node.");
336 CHECK_FAIL_RETURN_UNEXPECTED((IsUnaryOperator() && Children().empty()) || IsNaryOperator(),
337 "This node must be a unary operator with no child or an n-ary operator");
338 children_.push_back(child);
339 child->parent_ = this;
340 return Status::OK();
341 }
342
343 /*
344 * InsertChildAt(<pos>, <node>) inserts the <node> to be at the <pos> index of the vector of its child nodes.
345 * As in the convention of C++, <pos> starts at position 0.
 * If <pos> is negative or greater than the current number of children, an error is raised.
 * (<pos> equal to the number of children is allowed and appends at the end.)
347 */
InsertChildAt(int32_t pos,std::shared_ptr<DatasetNode> child)348 Status DatasetNode::InsertChildAt(int32_t pos, std::shared_ptr<DatasetNode> child) {
349 CHECK_FAIL_RETURN_UNEXPECTED(pos > -1 && pos <= children_.size(), "Position must in the range of [0, size]");
350 CHECK_FAIL_RETURN_UNEXPECTED(IsOrphanNode(child), "Node to append must be an orphan node.");
351 CHECK_FAIL_RETURN_UNEXPECTED((IsUnaryOperator() && Children().empty()) || IsNaryOperator(),
352 "This node must be a unary operator with no child or an n-ary operator");
353 children_.insert(children_.begin() + pos, child);
354 child->parent_ = this;
355 return Status::OK();
356 }
357
358 /*
359 * Insert the input <node> above this node
360 * Input tree:
361 * ds4
362 * / \
363 * ds3 ds2
364 * |
365 * ds1
366 *
367 * Case 1: If we want to insert a new node ds5 between ds4 and ds3, use
368 * ds3->InsertAbove(ds5)
369 *
370 * ds4
371 * / \
372 * ds5 ds2
373 * |
374 * ds3
375 * |
376 * ds1
377 *
378 * Case 2: Likewise, ds2->InsertAbove(ds6) yields
379 *
380 * ds4
381 * / \
382 * ds3 ds6
383 * | |
384 * ds1 ds2
385 *
386 * Case 3: We can insert a new node between ds3 and ds1 by ds1->InsertAbove(ds7)
387 *
388 * ds4
389 * / \
390 * ds3 ds2
391 * |
392 * ds7
393 * |
394 * ds1
395 *
396 * InsertAbove() cannot use on the root node of a tree.
397 */
InsertAbove(std::shared_ptr<DatasetNode> node)398 Status DatasetNode::InsertAbove(std::shared_ptr<DatasetNode> node) {
399 CHECK_FAIL_RETURN_UNEXPECTED(IsOrphanNode(node), "Node to insert must be an orphan node.");
400 CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "This node must not be the root or a node without parent.");
401 auto parent = parent_;
402
403 // The following fields of these three nodes are changed in this function:
404 // 1. parent->children_
405 // 2. node->parent_ and node->children_
406 // 3. this->parent_
407 auto current_node_itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this());
408 *current_node_itr = node; // replace me in my parent's children list with the newly inserted node
409 node->parent_ = parent; // set the newly inserted node's parent ptr to my parent
410 node->children_.push_back(shared_from_this()); // add myself to the newly inserted node's children list
411 parent_ = node.get(); // set my parent ptr to the newly inserted node
412
413 return Status::OK();
414 }
415
416 /*
417 * Drop() detaches this node from the tree it is in. Calling Drop() from a standalone node is a no-op.
418 *
419 * Input tree:
420 * ds10
421 * / \
422 * ds9 ds6
423 * | / | \
424 * ds8 ds5 ds4 ds1
425 * | / \
426 * ds7 ds3 ds2
427 *
428 * Case 1: When the node has no child and no sibling, Drop() detaches the node from its tree.
429 *
430 * ds7->Drop() yields the tree below:
431 *
432 * ds10
433 * / \
434 * ds9 ds6
435 * | / | \
436 * ds8 ds5 ds4 ds1
437 * / \
438 * ds3 ds2
439 *
440 * Case 2: When the node has one child and no sibling, Drop() detaches the node from its tree and the node's child
441 * becomes its parent's child.
442 *
443 * ds8->Drop() yields the tree below:
444 *
445 * ds10
446 * / \
447 * ds9 ds6
448 * | / | \
449 * ds7 ds5 ds4 ds1
450 * / \
451 * ds3 ds2
452 *
453 * Case 3: When the node has more than one child and no sibling, Drop() detaches the node from its tree and the node's
454 * children become its parent's children.
455 *
456 * When the input tree is
457 *
458 * ds10
459 * / \
460 * ds9 ds6
461 * | |
462 * ds8 ds4
463 * | / \
464 * ds7 ds3 ds2
465 *
466 * ds4->Drop() yields the tree below:
467 *
468 * ds10
469 * / \
470 * ds9 ds6
471 * | / \
472 * ds8 ds3 ds2
473 * |
474 * ds7
475 *
476 * But if ds6 is not an n-ary operator, ds4->Drop() will raise an error because we cannot add the children of an
477 * n-ary operator (ds4) to a unary operator (ds6).
478 *
479 * Case 4: When the node has no child but has siblings, Drop() detaches the node from its tree and its siblings will be
480 * squeezed left.
481 *
482 * Input tree:
483 * ds10
484 * / \
485 * ds9 ds6
486 * | / | \
487 * ds8 ds5 ds4 ds1
488 * | / \
489 * ds7 ds3 ds2
490 *
491 * ds5->Drop() yields the tree below:
492 *
493 * ds10
494 * / \
495 * ds9 ds6
496 * | / \
497 * ds8 ds4 ds1
498 * | / \
499 * ds7 ds3 ds2
500 *
501 * Case 5: When the node has only one child but has siblings, Drop() detaches the node from its tree and the node's
502 * children become its parent's children.
503 *
504 * Input tree:
505 * ds10
506 * / \
507 * ds9 ds6
508 * | / | \
509 * ds8 ds5 ds4 ds1
510 * | |
511 * ds7 ds3
512 *
513 * ds4->Drop() yields the tree below:
514 *
515 * ds10
516 * / \
517 * ds9 ds6
518 * | / | \
519 * ds8 ds5 ds3 ds1
520 * |
521 * ds7
522 *
 * Case 6: When the node has more than one child and more than one sibling, Drop() will raise an error,
 * because the parent cannot unambiguously absorb multiple children into an occupied child list.
 * If we want to drop ds4 from the input tree, ds4->Drop() will not work; the same effect must be
 * achieved with a combination of Drop() and InsertChildAt() calls.
526 *
527 * Input tree:
528 * ds10
529 * / \
530 * ds9 ds6
531 * | / | \
532 * ds8 ds5 ds4 ds1
533 * | / \
534 * ds7 ds3 ds2
535 *
536 * If we want to form this tree below:
537 *
538 * ds10
539 * / \
540 * ds9 ds6_____
541 * | / | | \
542 * ds8 ds5 ds3 ds2 ds1
543 * |
544 * ds7
545 *
546 */
Drop()547 Status DatasetNode::Drop() {
548 CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "This node to drop must not be the root or a node without parent.");
549 CHECK_FAIL_RETURN_UNEXPECTED(!(IsNaryOperator() && parent_->IsUnaryOperator()),
550 "Trying to drop an n-ary operator that is a child of a unary operator");
551 CHECK_FAIL_RETURN_UNEXPECTED(!(children_.size() > 1 && parent_->children_.size() > 1),
552 "This node to drop must not have more than one child and more than one sibling.");
553 if (parent_->children_.size() == 1) {
554 auto my_parent = parent_;
555 // Case 2: When the node has one child and no sibling, Drop() detaches the node from its tree and the node's child
556 // becomes its parent's child.
557 // This is the most common use case.
558 if (children_.size() == 1) {
559 auto child = children_[0];
560 // Move its child to be its parent's child
561 my_parent->children_[0] = child;
562 child->parent_ = my_parent;
563 } else if (children_.empty()) {
564 // Case 1: When the node has no child and no sibling, Drop() detaches the node from its tree.
565 // Remove this node from its parent's child
566 parent_->children_.clear();
567 } else if (children_.size() > 1) {
568 // Case 3: When the node has more than one child and no sibling, Drop() detaches the node from its tree and
569 // the node's children become its parent's children.
570 // Remove this node from its parent's child
571 my_parent->children_.clear();
572 // Move its child to be its parent's child
573 for (auto &my_child : children_) {
574 my_parent->children_.push_back(my_child);
575 my_child->parent_ = my_parent;
576 }
577 }
578 // And mark itself as an orphan
579 parent_ = nullptr;
580 children_.clear();
581 } else if (children_.empty() && parent_->children_.size() > 1) {
582 // Case 4: When the node has no child but has siblings, Drop() detaches the node from its tree and its siblings will
583 // be squeezed left.
584 auto parent = parent_;
585 // Remove this node from its parent's child
586 parent->children_.erase(std::remove(parent->children_.begin(), parent->children_.end(), shared_from_this()),
587 parent->children_.end()); // removal using "erase remove idiom"
588 // And mark itself as an orphan
589 parent_ = nullptr;
590 children_.clear();
591 } else if (children_.size() == 1 && parent_->children_.size() > 1) {
592 // Case 5: When the node has only one child but has siblings, Drop() detaches the node from its tree and the node's
593 // children become its parent's children.
594 auto itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this());
595 CHECK_FAIL_RETURN_UNEXPECTED(itr != parent_->children_.end(), "I am not in my parent's children list.");
596 *itr = children_[0]; // replace this node in its parent's children list with its single child
597 children_[0]->parent_ = parent_; // set its single child's parent ptr to its parent
598 // And mark itself as an orphan
599 parent_ = nullptr;
600 children_.clear();
601 } else {
602 RETURN_STATUS_UNEXPECTED("Internal error: we should not reach here.");
603 }
604 return Status::OK();
605 }
606
607 // In DFS tree traversal, each node is visited twice. Accept is called on the first visit.
Accept(IRNodePass * const p,bool * const modified)608 Status DatasetNode::Accept(IRNodePass *const p, bool *const modified) {
609 // This method will only be called if its derived class does not implement one.
610 return p->Visit(shared_from_this(), modified);
611 }
612
613 // In DFS tree traversal, each node is visited twice. AcceptAfter is called on the second visit
614 // after all child nodes are visited.
AcceptAfter(IRNodePass * const p,bool * const modified)615 Status DatasetNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
616 // This method will only be called if its derived class does not implement one.
617 return p->VisitAfter(shared_from_this(), modified);
618 }
619
GetShardId(int32_t * const shard_id)620 Status DatasetNode::GetShardId(int32_t *const shard_id) {
621 if (children_.size() == 1) {
622 // Get shard id from the child node
623 return children_[0]->GetShardId(shard_id);
624 } else if (children_.size() > 1) {
625 // It is okay for dataset to have more than 1 child, GetShardId shouldn't fail in this case.
626 // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
627 // always be in front of the child_ structure, so we get the dataset size from the last child.
628 return children_.back()->GetShardId(shard_id);
629 } else {
630 LOG_AND_RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n");
631 }
632 }
633
634 // Gets the dataset size
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)635 Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
636 int64_t *dataset_size) {
637 if (dataset_size_ > 0) {
638 *dataset_size = dataset_size_;
639 return Status::OK();
640 }
641 if (!IsSizeDefined() && size_getter != nullptr) {
642 RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), dataset_size));
643 dataset_size_ = *dataset_size;
644 return Status::OK();
645 }
646 if (children_.size() == 1) {
647 return children_.front()->GetDatasetSize(size_getter, estimate, dataset_size);
648 } else if (children_.size() > 1) {
649 // It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case.
650 // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
651 // always be in front of the child_ structure, so we get the dataset size from the last child.
652 return children_.back()->GetDatasetSize(size_getter, estimate, dataset_size);
653 } else {
654 RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
655 }
656 }
ValidateParams()657 Status DatasetNode::ValidateParams() {
658 int32_t num_threads = GlobalContext::config_manager()->num_cpu_threads();
659 // in case std::thread::hardware_concurrency returns 0, use an artificial upper limit
660 num_threads = num_threads > 0 ? num_threads : std::numeric_limits<uint16_t>::max();
661 CHECK_FAIL_RETURN_UNEXPECTED(
662 num_workers_ > 0 && num_workers_ <= num_threads,
663 Name() + "'s num_workers=" + std::to_string(num_workers_) +
664 ", this value is not within the required range of [1, cpu_thread_cnt=" + std::to_string(num_threads) + "].");
665 return Status::OK();
666 }
667
to_json(nlohmann::json * out_json)668 Status DatasetNode::to_json(nlohmann::json *out_json) {
669 nlohmann::json args;
670 args["num_parallel_workers"] = num_workers_;
671 args["connector_queue_size"] = connector_que_size_;
672 *out_json = args;
673 return Status::OK();
674 }
675
Accept(IRNodePass * const p,bool * const modified)676 Status MappableSourceNode::Accept(IRNodePass *const p, bool *const modified) {
677 return p->Visit(shared_from_base<MappableSourceNode>(), modified);
678 }
679
Accept(IRNodePass * const p,bool * const modified)680 Status NonMappableSourceNode::Accept(IRNodePass *const p, bool *const modified) {
681 return p->Visit(shared_from_base<NonMappableSourceNode>(), modified);
682 }
683
684 } // namespace dataset
685 } // namespace mindspore
686