1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
18
19 #include <algorithm>
20 #include <limits>
21 #include <memory>
22 #include <set>
23
24 #include "minddata/dataset/engine/opt/pass.h"
25 #include "minddata/dataset/util/random.h"
26 #include "minddata/dataset/util/status.h"
27
28 namespace mindspore {
29 namespace dataset {
30
31 // Helper function to compute a default shuffle size
ComputeShuffleSize(int64_t num_files,int64_t num_devices,int64_t num_rows,int64_t total_rows,int64_t * shuffle_size)32 Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
33 int64_t *shuffle_size) {
34 RETURN_UNEXPECTED_IF_NULL(shuffle_size);
35 const int64_t average_files_multiplier = 4;
36 const int64_t shuffle_max = 10000;
37 int64_t avg_rows_per_file = 0;
38
39 // Adjust the num rows per shard if sharding was given
40 if (num_devices > 0) {
41 if (num_rows % num_devices == 0) {
42 num_rows = num_rows / num_devices;
43 } else {
44 num_rows = (num_rows / num_devices) + 1;
45 }
46 }
47
48 // Cap based on total rows directive. Some ops do not have this and give value of 0.
49 if (total_rows > 0) {
50 num_rows = std::min(num_rows, total_rows);
51 }
52
53 // get the average per file
54 CHECK_FAIL_RETURN_UNEXPECTED(num_files != 0, "The size of dataset_files must greater than 0.");
55 avg_rows_per_file = num_rows / num_files;
56
57 *shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max);
58 return Status::OK();
59 }
60
61 // Helper function to inject a shuffle operator over top of current operator being built
AddShuffleOp(int64_t num_files,int64_t num_devices,int64_t num_rows,int64_t total_rows,int32_t connector_que_size,std::shared_ptr<DatasetOp> * shuffle_op)62 Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
63 int32_t connector_que_size, std::shared_ptr<DatasetOp> *shuffle_op) {
64 RETURN_UNEXPECTED_IF_NULL(shuffle_op);
65 int64_t shuffle_size = 0;
66 RETURN_IF_NOT_OK(ComputeShuffleSize(num_files, num_devices, num_rows, total_rows, &shuffle_size));
67 MS_LOG(INFO) << "Dataset::AddShuffleOp - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size;
68 // Add the shuffle op
69 *shuffle_op = std::make_shared<ShuffleOp>(shuffle_size, GetSeed(), connector_que_size, true);
70 return Status::OK();
71 }
72
73 // Helper function to validate dataset directory parameter
ValidateDatasetDirParam(const std::string & dataset_name,std::string dataset_dir)74 Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) {
75 if (dataset_dir.empty()) {
76 std::string err_msg = dataset_name + ": dataset_dir is not specified.";
77 MS_LOG(ERROR) << err_msg;
78 RETURN_STATUS_SYNTAX_ERROR(err_msg);
79 }
80
81 std::string real_path;
82 RETURN_IF_NOT_OK(Path::RealPath(dataset_dir, real_path));
83 Path dir(dataset_dir);
84 if (!dir.IsDirectory()) {
85 std::string err_msg = dataset_name + ": dataset_dir: [" + dataset_dir + "] is an invalid directory path.";
86 MS_LOG(ERROR) << err_msg;
87 RETURN_STATUS_SYNTAX_ERROR(err_msg);
88 }
89
90 if (access(dataset_dir.c_str(), R_OK) == -1) {
91 std::string err_msg = dataset_name + ": No access to specified dataset path: " + dataset_dir;
92 MS_LOG(ERROR) << err_msg;
93 RETURN_STATUS_SYNTAX_ERROR(err_msg);
94 }
95
96 return Status::OK();
97 }
98
99 // Helper function to validate dataset files parameter
ValidateDatasetFilesParam(const std::string & dataset_name,const std::vector<std::string> & dataset_files)100 Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files) {
101 if (dataset_files.empty()) {
102 std::string err_msg = dataset_name + ": dataset_files is not specified.";
103 MS_LOG(ERROR) << err_msg;
104 RETURN_STATUS_SYNTAX_ERROR(err_msg);
105 }
106
107 for (auto f : dataset_files) {
108 Path dataset_file(f);
109 if (!dataset_file.Exists()) {
110 std::string err_msg = dataset_name + ": dataset file: [" + f + "] is invalid or does not exist.";
111 MS_LOG(ERROR) << err_msg;
112
113 RETURN_STATUS_SYNTAX_ERROR(err_msg);
114 }
115 if (access(dataset_file.ToString().c_str(), R_OK) == -1) {
116 std::string err_msg = dataset_name + ": No access to specified dataset file: " + f;
117 MS_LOG(ERROR) << err_msg;
118 RETURN_STATUS_SYNTAX_ERROR(err_msg);
119 }
120 }
121
122 return Status::OK();
123 }
124
125 // Helper function to validate dataset num_shards and shard_id parameters
ValidateDatasetShardParams(const std::string & dataset_name,int32_t num_shards,int32_t shard_id)126 Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) {
127 if (num_shards <= 0) {
128 std::string err_msg = dataset_name + ": Invalid num_shards: " + std::to_string(num_shards);
129 MS_LOG(ERROR) << err_msg;
130 RETURN_STATUS_SYNTAX_ERROR(err_msg);
131 }
132
133 if (shard_id < 0 || shard_id >= num_shards) {
134 // num_shards
135 std::string err_msg = dataset_name + ": Invalid input, shard_id: " + std::to_string(shard_id) +
136 ", num_shards: " + std::to_string(num_shards);
137 MS_LOG(ERROR) << err_msg;
138 RETURN_STATUS_SYNTAX_ERROR(err_msg);
139 }
140
141 return Status::OK();
142 }
143
144 // Helper function to validate dataset sampler parameter
ValidateDatasetSampler(const std::string & dataset_name,const std::shared_ptr<SamplerObj> & sampler)145 Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler) {
146 if (sampler == nullptr) {
147 std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr";
148 MS_LOG(ERROR) << err_msg;
149 RETURN_STATUS_SYNTAX_ERROR(err_msg);
150 }
151 RETURN_IF_NOT_OK(sampler->ValidateParams());
152 return Status::OK();
153 }
154
ValidateStringValue(const std::string & dataset_name,const std::string & str,const std::unordered_set<std::string> & valid_strings)155 Status ValidateStringValue(const std::string &dataset_name, const std::string &str,
156 const std::unordered_set<std::string> &valid_strings) {
157 if (valid_strings.find(str) == valid_strings.end()) {
158 std::string init;
159 std::string mode = std::accumulate(valid_strings.begin(), valid_strings.end(), init,
160 [](std::string a, std::string b) { return std::move(a) + " " + std::move(b); });
161 std::string err_msg = dataset_name + ": " + str + " does not match any mode in [" + mode + " ]";
162 MS_LOG(ERROR) << err_msg;
163 RETURN_STATUS_SYNTAX_ERROR(err_msg);
164 }
165 return Status::OK();
166 }
167
168 // Helper function to validate dataset input/output column parameter
ValidateDatasetColumnParam(const std::string & dataset_name,const std::string & column_param,const std::vector<std::string> & columns)169 Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param,
170 const std::vector<std::string> &columns) {
171 if (columns.empty()) {
172 std::string err_msg = dataset_name + ":" + column_param + " should not be empty string";
173 MS_LOG(ERROR) << err_msg;
174 RETURN_STATUS_SYNTAX_ERROR(err_msg);
175 }
176 for (uint32_t i = 0; i < columns.size(); ++i) {
177 if (columns[i].empty()) {
178 std::string err_msg = dataset_name + ":" + column_param + "[" + std::to_string(i) + "] must not be empty";
179 MS_LOG(ERROR) << err_msg;
180 RETURN_STATUS_SYNTAX_ERROR(err_msg);
181 }
182 }
183 std::set<std::string> columns_set;
184 for (auto &column_name : columns) {
185 auto result = columns_set.insert(column_name);
186 if (result.second == false) {
187 std::string err_msg = dataset_name + ":" + column_param +
188 ": Invalid parameter, duplicate column names are not allowed: " + *result.first;
189 MS_LOG(ERROR) << err_msg;
190 RETURN_STATUS_SYNTAX_ERROR(err_msg);
191 }
192 }
193 return Status::OK();
194 }
195
SelectSampler(int64_t num_samples,bool shuffle,int32_t num_shards,int32_t shard_id)196 std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id) {
197 if (shuffle) {
198 if (num_shards > 1) {
199 // If shuffle enabled, sharding enabled, use distributed random sampler
200 return DistributedSampler(num_shards, shard_id, shuffle, num_samples).Parse();
201 }
202 // If shuffle enabled, sharding disabled, use random sampler
203 return RandomSampler(num_samples >= 0, num_samples).Parse();
204 }
205 if (num_shards > 1) {
206 // If shuffle disabled, sharding enabled, use distributed sequential sampler
207 return DistributedSampler(num_shards, shard_id, shuffle, num_samples).Parse();
208 }
209 // If shuffle disabled, sharding disabled, use sequential sampler
210 return SequentialSampler(0, num_samples).Parse();
211 }
212
213 // Constructor to initialize the cache
DatasetNode(const std::shared_ptr<DatasetCache> & dataset_cache)214 DatasetNode::DatasetNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode() { cache_ = dataset_cache; }
215
SetNumWorkers(int32_t num_workers)216 std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) {
217 num_workers_ = num_workers;
218 return shared_from_this();
219 }
220
SetDatasetCache(const std::shared_ptr<DatasetCache> & cache)221 std::shared_ptr<DatasetNode> DatasetNode::SetDatasetCache(const std::shared_ptr<DatasetCache> &cache) {
222 cache_ = cache;
223 return shared_from_this();
224 }
225
// Default constructor: initialize bookkeeping fields to their "unset" sentinels
// and pull worker/queue defaults from the global config manager.
DatasetNode::DatasetNode()
    : cache_(nullptr),
      parent_(nullptr),
      children_({}),
      dataset_size_(-1),  // -1 marks "size not computed yet"; GetDatasetSize caches into this
      mappable_(kNotADataSource),
      nary_op_(false),
      descendant_of_cache_(false),
      total_repeats_(-1),
      num_epochs_(1) {
  // Fetch some default value from config manager
  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  num_workers_ = cfg->num_parallel_workers();
  connector_que_size_ = cfg->op_connector_size();
  worker_connector_size_ = cfg->worker_connector_size();
}
242
PrintColumns(const std::vector<std::string> & columns) const243 std::string DatasetNode::PrintColumns(const std::vector<std::string> &columns) const {
244 std::string me;
245 if (columns.empty()) {
246 me = "<nil>";
247 } else {
248 me = "[";
249 auto i = 0;
250 for (auto it = columns.begin(); it < columns.end(); ++it, ++i) {
251 me += *it;
252 if (i < columns.size() - 1) {
253 me += ", ";
254 } else {
255 me += "]";
256 }
257 }
258 }
259 return me;
260 }
261
PrintTree(std::ostream & out) const262 void DatasetNode::PrintTree(std::ostream &out) const {
263 int level = 0;
264 PrintNode(out, &level);
265 }
266
PrintNode(std::ostream & out,int * level) const267 void DatasetNode::PrintNode(std::ostream &out, int *level) const {
268 const std::string prefix = "+-";
269 const std::string indent = "| ";
270 out << prefix;
271 Print(out);
272 for (const auto &c : this->Children()) {
273 out << '\n';
274 ++(*level);
275 for (auto i = 0; i < *level; i++) {
276 out << indent;
277 }
278 c->PrintNode(out, level);
279 --(*level);
280 }
281 }
282
283 // Add a node as a child, node's parent needs to be empty
284 // This function will allow child to be a nullptr, in which case it will simply skip.
285 // This function is used only when building IR node one by one from parsing the user code.
286 // During the parsing, we allow a node to have more than one parent, possibly forming a graph.
287 // It does not maintain the parent_ attribute of the node, which enforces a single parent and a tree structure.
AddChild(std::shared_ptr<DatasetNode> child)288 void DatasetNode::AddChild(std::shared_ptr<DatasetNode> child) {
289 if (child != nullptr) {
290 children_.push_back(child);
291 }
292 }
293
294 /*
295 * AppendChild(<node>) appending <node> as the last child of this node. The new node must have no parent.
296 *
297 * Input tree:
298 * ds4
299 * / \
300 * ds3 ds2
301 * |
302 * ds1
303 *
304 * ds4->AppendChild(ds6) yields this tree
305 *
306 * _ ds4 _
307 * / | \
308 * ds3 ds2 ds6
309 * |
310 * ds1
311 *
312 */
AppendChild(std::shared_ptr<DatasetNode> child)313 Status DatasetNode::AppendChild(std::shared_ptr<DatasetNode> child) {
314 CHECK_FAIL_RETURN_UNEXPECTED(IsOrphanNode(child), "Node to append must be an orphan node.");
315 CHECK_FAIL_RETURN_UNEXPECTED((IsUnaryOperator() && Children().empty()) || IsNaryOperator(),
316 "This node must be a unary operator with no child or an n-ary operator");
317 children_.push_back(child);
318 child->parent_ = this;
319 return Status::OK();
320 }
321
322 /*
323 * InsertChildAt(<pos>, <node>) inserts the <node> to be at the <pos> index of the vector of its child nodes.
324 * As in the convention of C++, <pos> starts at position 0.
325 * If the <pos> is a negative number or larger than the size of the vector minus one, an error is raised.
326 */
InsertChildAt(int32_t pos,std::shared_ptr<DatasetNode> child)327 Status DatasetNode::InsertChildAt(int32_t pos, std::shared_ptr<DatasetNode> child) {
328 CHECK_FAIL_RETURN_UNEXPECTED(pos > -1 && pos <= children_.size(), "Position must in the range of [0, size]");
329 CHECK_FAIL_RETURN_UNEXPECTED(IsOrphanNode(child), "Node to append must be an orphan node.");
330 CHECK_FAIL_RETURN_UNEXPECTED((IsUnaryOperator() && Children().empty()) || IsNaryOperator(),
331 "This node must be a unary operator with no child or an n-ary operator");
332 children_.insert(children_.begin() + pos, child);
333 child->parent_ = this;
334 return Status::OK();
335 }
336
337 /*
338 * Insert the input <node> above this node
339 * Input tree:
340 * ds4
341 * / \
342 * ds3 ds2
343 * |
344 * ds1
345 *
346 * Case 1: If we want to insert a new node ds5 between ds4 and ds3, use
347 * ds3->InsertAbove(ds5)
348 *
349 * ds4
350 * / \
351 * ds5 ds2
352 * |
353 * ds3
354 * |
355 * ds1
356 *
357 * Case 2: Likewise, ds2->InsertAbove(ds6) yields
358 *
359 * ds4
360 * / \
361 * ds3 ds6
362 * | |
363 * ds1 ds2
364 *
365 * Case 3: We can insert a new node between ds3 and ds1 by ds1->InsertAbove(ds7)
366 *
367 * ds4
368 * / \
369 * ds3 ds2
370 * |
371 * ds7
372 * |
373 * ds1
374 *
375 * InsertAbove() cannot use on the root node of a tree.
376 */
InsertAbove(std::shared_ptr<DatasetNode> node)377 Status DatasetNode::InsertAbove(std::shared_ptr<DatasetNode> node) {
378 CHECK_FAIL_RETURN_UNEXPECTED(IsOrphanNode(node), "Node to insert must be an orphan node.");
379 CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "This node must not be the root or a node without parent.");
380 auto parent = parent_;
381
382 // The following fields of these three nodes are changed in this function:
383 // 1. parent->children_
384 // 2. node->parent_ and node->children_
385 // 3. this->parent_
386 auto current_node_itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this());
387 *current_node_itr = node; // replace me in my parent's children list with the newly inserted node
388 node->parent_ = parent; // set the newly inserted node's parent ptr to my parent
389 node->children_.push_back(shared_from_this()); // add myself to the newly inserted node's children list
390 parent_ = node.get(); // set my parent ptr to the newly inserted node
391
392 return Status::OK();
393 }
394
395 /*
396 * Drop() detaches this node from the tree it is in. Calling Drop() from a standalone node is a no-op.
397 *
398 * Input tree:
399 * ds10
400 * / \
401 * ds9 ds6
402 * | / | \
403 * ds8 ds5 ds4 ds1
404 * | / \
405 * ds7 ds3 ds2
406 *
407 * Case 1: When the node has no child and no sibling, Drop() detaches the node from its tree.
408 *
409 * ds7->Drop() yields the tree below:
410 *
411 * ds10
412 * / \
413 * ds9 ds6
414 * | / | \
415 * ds8 ds5 ds4 ds1
416 * / \
417 * ds3 ds2
418 *
419 * Case 2: When the node has one child and no sibling, Drop() detaches the node from its tree and the node's child
420 * becomes its parent's child.
421 *
422 * ds8->Drop() yields the tree below:
423 *
424 * ds10
425 * / \
426 * ds9 ds6
427 * | / | \
428 * ds7 ds5 ds4 ds1
429 * / \
430 * ds3 ds2
431 *
432 * Case 3: When the node has more than one child and no sibling, Drop() detaches the node from its tree and the node's
433 * children become its parent's children.
434 *
435 * When the input tree is
436 *
437 * ds10
438 * / \
439 * ds9 ds6
440 * | |
441 * ds8 ds4
442 * | / \
443 * ds7 ds3 ds2
444 *
445 * ds4->Drop() yields the tree below:
446 *
447 * ds10
448 * / \
449 * ds9 ds6
450 * | / \
451 * ds8 ds3 ds2
452 * |
453 * ds7
454 *
455 * But if ds6 is not an n-ary operator, ds4->Drop() will raise an error because we cannot add the children of an
456 * n-ary operator (ds4) to a unary operator (ds6).
457 *
458 * Case 4: When the node has no child but has siblings, Drop() detaches the node from its tree and its siblings will be
459 * squeezed left.
460 *
461 * Input tree:
462 * ds10
463 * / \
464 * ds9 ds6
465 * | / | \
466 * ds8 ds5 ds4 ds1
467 * | / \
468 * ds7 ds3 ds2
469 *
470 * ds5->Drop() yields the tree below:
471 *
472 * ds10
473 * / \
474 * ds9 ds6
475 * | / \
476 * ds8 ds4 ds1
477 * | / \
478 * ds7 ds3 ds2
479 *
480 * Case 5: When the node has only one child but has siblings, Drop() detaches the node from its tree and the node's
481 * children become its parent's children.
482 *
483 * Input tree:
484 * ds10
485 * / \
486 * ds9 ds6
487 * | / | \
488 * ds8 ds5 ds4 ds1
489 * | |
490 * ds7 ds3
491 *
492 * ds4->Drop() yields the tree below:
493 *
494 * ds10
495 * / \
496 * ds9 ds6
497 * | / | \
498 * ds8 ds5 ds3 ds1
499 * |
500 * ds7
501 *
502 * Case 6: When the node has more than one child and more than one sibling, Drop() will raise an error.
503 * If we want to drop ds4 from the input tree, ds4->Drop() will not work. We will have to do it
504 * with a combination of Drop(), InsertChildAt()
505 *
506 * Input tree:
507 * ds10
508 * / \
509 * ds9 ds6
510 * | / | \
511 * ds8 ds5 ds4 ds1
512 * | / \
513 * ds7 ds3 ds2
514 *
515 * If we want to form this tree below:
516 *
517 * ds10
518 * / \
519 * ds9 ds6_____
520 * | / | | \
521 * ds8 ds5 ds3 ds2 ds1
522 * |
523 * ds7
524 *
525 */
Status DatasetNode::Drop() {
  // Preconditions: a root cannot be dropped; an n-ary node cannot hand its
  // children to a unary parent; and Case 6 (multiple children AND multiple
  // siblings) is rejected outright -- the caller must decompose it manually.
  CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "This node to drop must not be the root or a node without parent.");
  CHECK_FAIL_RETURN_UNEXPECTED(!(IsNaryOperator() && parent_->IsUnaryOperator()),
                               "Trying to drop an n-ary operator that is a child of a unary operator");
  CHECK_FAIL_RETURN_UNEXPECTED(!(children_.size() > 1 && parent_->children_.size() > 1),
                               "This node to drop must not have more than one child and more than one sibling.");
  // Branch on sibling count first: "only child" covers Cases 1-3, "has siblings" covers Cases 4-5.
  if (parent_->children_.size() == 1) {
    auto my_parent = parent_;
    // Case 2: When the node has one child and no sibling, Drop() detaches the node from its tree and the node's child
    // becomes its parent's child.
    // This is the most common use case.
    if (children_.size() == 1) {
      auto child = children_[0];
      // Move its child to be its parent's child
      my_parent->children_[0] = child;
      child->parent_ = my_parent;
    } else if (children_.empty()) {
      // Case 1: When the node has no child and no sibling, Drop() detaches the node from its tree.
      // Remove this node from its parent's child
      parent_->children_.clear();
    } else if (children_.size() > 1) {
      // Case 3: When the node has more than one child and no sibling, Drop() detaches the node from its tree and
      // the node's children become its parent's children.
      // Remove this node from its parent's child
      my_parent->children_.clear();
      // Move its child to be its parent's child
      for (auto &my_child : children_) {
        my_parent->children_.push_back(my_child);
        my_child->parent_ = my_parent;
      }
    }
    // And mark itself as an orphan (no parent, no children)
    parent_ = nullptr;
    children_.clear();
  } else if (children_.empty() && parent_->children_.size() > 1) {
    // Case 4: When the node has no child but has siblings, Drop() detaches the node from its tree and its siblings will
    // be squeezed left.
    auto parent = parent_;
    // Remove this node from its parent's child
    parent->children_.erase(std::remove(parent->children_.begin(), parent->children_.end(), shared_from_this()),
                            parent->children_.end());  // removal using "erase remove idiom"
    // And mark itself as an orphan
    parent_ = nullptr;
    children_.clear();
  } else if (children_.size() == 1 && parent_->children_.size() > 1) {
    // Case 5: When the node has only one child but has siblings, Drop() detaches the node from its tree and the node's
    // children become its parent's children.
    auto itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this());
    CHECK_FAIL_RETURN_UNEXPECTED(itr != parent_->children_.end(), "I am not in my parent's children list.");
    *itr = children_[0];              // replace this node in its parent's children list with its single child
    children_[0]->parent_ = parent_;  // set its single child's parent ptr to its parent
    // And mark itself as an orphan
    parent_ = nullptr;
    children_.clear();
  } else {
    // Unreachable: the three CHECKs above plus the branches cover every legal shape.
    RETURN_STATUS_UNEXPECTED("Internal error: we should not reach here.");
  }
  return Status::OK();
}
585
586 // In DFS tree traversal, each node is visited twice. Accept is called on the first visit.
Accept(IRNodePass * const p,bool * const modified)587 Status DatasetNode::Accept(IRNodePass *const p, bool *const modified) {
588 // This method will only be called if its derived class does not implement one.
589 return p->Visit(shared_from_this(), modified);
590 }
591
592 // In DFS tree traversal, each node is visited twice. AcceptAfter is called on the second visit
593 // after all child nodes are visited.
AcceptAfter(IRNodePass * const p,bool * const modified)594 Status DatasetNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
595 // This method will only be called if its derived class does not implement one.
596 return p->VisitAfter(shared_from_this(), modified);
597 }
598
GetShardId(int32_t * const shard_id)599 Status DatasetNode::GetShardId(int32_t *const shard_id) {
600 if (children_.size() == 1) {
601 // Get shard id from the child node
602 return children_[0]->GetShardId(shard_id);
603 } else if (children_.size() > 1) {
604 // It is okay for dataset to have more than 1 child, GetShardId shouldn't fail in this case.
605 // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
606 // always be in front of the child_ structure, so we get the dataset size from the last child.
607 return children_.back()->GetShardId(shard_id);
608 } else {
609 RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n");
610 }
611 }
612
613 // Gets the dataset size
GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> & size_getter,bool estimate,int64_t * dataset_size)614 Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
615 int64_t *dataset_size) {
616 if (dataset_size_ > 0) {
617 *dataset_size = dataset_size_;
618 return Status::OK();
619 }
620 if (!IsSizeDefined()) {
621 RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), dataset_size));
622 dataset_size_ = *dataset_size;
623 return Status::OK();
624 }
625 if (children_.size() == 1) {
626 return children_.front()->GetDatasetSize(size_getter, estimate, dataset_size);
627 } else if (children_.size() > 1) {
628 // It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case.
629 // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
630 // always be in front of the child_ structure, so we get the dataset size from the last child.
631 return children_.back()->GetDatasetSize(size_getter, estimate, dataset_size);
632 } else {
633 RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
634 }
635 }
ValidateParams()636 Status DatasetNode::ValidateParams() {
637 int32_t num_threads = GlobalContext::config_manager()->num_cpu_threads();
638 // in case std::thread::hardware_concurrency returns 0, use an artificial upper limit
639 num_threads = num_threads > 0 ? num_threads : std::numeric_limits<uint16_t>::max();
640 CHECK_FAIL_RETURN_UNEXPECTED(
641 num_workers_ > 0 && num_workers_ <= num_threads,
642 Name() + "'s num_workers=" + std::to_string(num_workers_) +
643 ", this value is not within the required range of [1, cpu_thread_cnt=" + std::to_string(num_threads) + "].");
644 return Status::OK();
645 }
646
to_json(nlohmann::json * out_json)647 Status DatasetNode::to_json(nlohmann::json *out_json) {
648 nlohmann::json args;
649 args["num_parallel_workers"] = num_workers_;
650 *out_json = args;
651 return Status::OK();
652 }
653
Accept(IRNodePass * const p,bool * const modified)654 Status MappableSourceNode::Accept(IRNodePass *const p, bool *const modified) {
655 return p->Visit(shared_from_base<MappableSourceNode>(), modified);
656 }
657
Accept(IRNodePass * const p,bool * const modified)658 Status NonMappableSourceNode::Accept(IRNodePass *const p, bool *const modified) {
659 return p->Visit(shared_from_base<NonMappableSourceNode>(), modified);
660 }
661
662 } // namespace dataset
663 } // namespace mindspore
664