1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/opt/post/generator_node_pass.h"
18 #include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
19
20 namespace mindspore {
21 namespace dataset {
22
GeneratorNodePass()23 GeneratorNodePass::GeneratorNodePass() : repeat_ancestors_({}) {}
24 /*
 * A diagram shows how the code works:
 *   With the tree below as an input
 *
 *           EpochCtrl(-1)
 *              /    \
 *         Repeat1    \
 *          /          Repeat3
 *        ..            \
 *       /               Generator2
 *    Repeat2            Add: Gen2-Rep3
 *     /
 *  Generator1
 *  Add: Gen1-Rep2
38 *
 *   The sequence of the DFS walk of the tree looks like this:
 *   1) Visit(EpochCtrl): push EpochCtrl, repeat_ancestors_ = { EpochCtrl }
 *   2) Visit(Repeat1): push Repeat1, repeat_ancestors_ = { EpochCtrl, Repeat1 }
 *   3) Visit(Repeat2): push Repeat2, repeat_ancestors_ = { EpochCtrl, Repeat1, Repeat2 }
 *   4) Visit(Generator1): record Repeat2 as its ancestor
 *                         record Repeat1 as Repeat2's ancestor
 *                         record EpochCtrl as Repeat1's ancestor
 *   5) VisitAfter(Repeat2): pop Repeat2, repeat_ancestors_ = { EpochCtrl, Repeat1 }
 *   6) VisitAfter(Repeat1): pop Repeat1, repeat_ancestors_ = { EpochCtrl }
 *   7) Visit(Repeat3): push Repeat3, repeat_ancestors_ = { EpochCtrl, Repeat3 }
 *   8) Visit(Generator2): record Repeat3 as its ancestor
 *                         record EpochCtrl as Repeat3's ancestor
 *   9) VisitAfter(Repeat3): pop Repeat3, repeat_ancestors_ = { EpochCtrl }
 *   10) VisitAfter(EpochCtrl): nothing to do. EpochCtrl could be popped here, but the walk ends at this
 *                              node, so it is simply left in repeat_ancestors_.
53 */
54
Visit(std::shared_ptr<EpochCtrlNode> node,bool * const modified)55 Status GeneratorNodePass::Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
56 // Add this EpochCtrl node as an ancestor of its descendant
57 repeat_ancestors_.push_back(node);
58 return Status::OK();
59 }
60
Visit(std::shared_ptr<RepeatNode> node,bool * const modified)61 Status GeneratorNodePass::Visit(std::shared_ptr<RepeatNode> node, bool *const modified) {
62 // Add this Repeat node as an ancestor of its descendant
63 repeat_ancestors_.push_back(node);
64 return Status::OK();
65 }
66
Visit(std::shared_ptr<GeneratorNode> node,bool * const modified)67 Status GeneratorNodePass::Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) {
68 // Form a reset relationship with the immediate Repeat/EpochCtrl ancestor node of this leaf Generator Node
69 // only when any of its ancestors is an infinite repeat.
70 if (repeat_ancestors_.size() > 0) {
71 bool infinite_repeat = false;
72 for (auto &repeat_ancestor : repeat_ancestors_) {
73 if (repeat_ancestor->Count() < 0) {
74 infinite_repeat = true;
75 break;
76 }
77 }
78 if (infinite_repeat) {
79 // Form a pair-wise relationship between this leaf Generator node and its immediate Repeat/EpochCtrl
80 // ancestor node, and between the next adjacent pairs in the vector. For example,
81 // if we have GeneratorNode -> Repeat1 -> Repeat2 -> EpochCtrl(-1), the pair-wise relationships are:
82 // (GeneratorNode, Repeat1), (Repeat1, Repeat2), and (Repeat2, EpochCtrl)
83 for (auto i = repeat_ancestors_.size() - 1; i > 0; --i) {
84 auto ancestor = repeat_ancestors_[i - 1];
85 RETURN_IF_NOT_OK(repeat_ancestors_[i]->AddResetAncestor(ancestor));
86 }
87 RETURN_IF_NOT_OK(node->AddResetAncestor(repeat_ancestors_.back()));
88 }
89 }
90 return Status::OK();
91 }
92
VisitAfter(std::shared_ptr<RepeatNode> node,bool * const modified)93 Status GeneratorNodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) {
94 // When we backtrack from the same Repeat node, we pop it out from the list of ancestors.
95 repeat_ancestors_.pop_back();
96 return Status::OK();
97 }
98
VisitAfter(std::shared_ptr<EpochCtrlNode> node,bool * const modified)99 Status GeneratorNodePass::VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
100 // As EpochCtrl node is a terminal node, the process stops here.
101 // Popping it back out of the reset ancestors is unnecessary.
102 // This function becomes a no-op function and can be deleted completely.
103 return Status::OK();
104 }
105
106 } // namespace dataset
107 } // namespace mindspore
108