• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "minddata/dataset/engine/opt/post/generator_node_pass.h"

#include <algorithm>

#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
19 
20 namespace mindspore {
21 namespace dataset {
22 
GeneratorNodePass()23 GeneratorNodePass::GeneratorNodePass() : repeat_ancestors_({}) {}
24 /*
25  * The diagram below shows how the code works:
26  *   With the tree below as an input
27  *
28  *             EpochCtrl(-1)
29  *              /    \
30  *          Repeat1   \
31  *           /      Repeat3
32  *          ..          \
33  *         /          Generator2
34  *      Repeat2       Add: Gen2-Rep3
35  *       /
36  *     Generator1
37  *     Add: Gen1-Rep2
38  *
39  * The sequence of the DFS walk of the tree looks like this:
40  *   1) Visit(EpochCtrl): push EpochCtrl, repeat_ancestors_ = { EpochCtrl }
41  *   2) Visit(Repeat1): push Repeat1, repeat_ancestors_ = { EpochCtrl, Repeat1 }
42  *   3) Visit(Repeat2): push Repeat2, repeat_ancestors_ = { EpochCtrl, Repeat1, Repeat2 }
43  *   4) Visit(Generator1): record Repeat2 as its ancestor
44  *                         record Repeat1 as Repeat2's ancestor
45  *                         record EpochCtrl as Repeat1's ancestor
46  *   5) VisitAfter(Repeat2): pop Repeat2, repeat_ancestors_ = { EpochCtrl, Repeat1 }
47  *   6) VisitAfter(Repeat1): pop Repeat1, repeat_ancestors_ = { EpochCtrl }
48  *   7) Visit(Repeat3): push Repeat3, repeat_ancestors_ = { EpochCtrl, Repeat3 }
49  *   8) Visit(Generator2): record Repeat3 as its ancestor
50  *                         record EpochCtrl as Repeat3's ancestor
51  *   9) VisitAfter(Repeat3): pop Repeat3, repeat_ancestors_ = { EpochCtrl }
52  *   10) VisitAfter(EpochCtrl): no-op; EpochCtrl is not popped from repeat_ancestors_.
53  */
54 
Visit(std::shared_ptr<EpochCtrlNode> node,bool * const modified)55 Status GeneratorNodePass::Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
56   // Add this EpochCtrl node as an ancestor of its descendant
57   repeat_ancestors_.push_back(node);
58   return Status::OK();
59 }
60 
Visit(std::shared_ptr<RepeatNode> node,bool * const modified)61 Status GeneratorNodePass::Visit(std::shared_ptr<RepeatNode> node, bool *const modified) {
62   // Add this Repeat node as an ancestor of its descendant
63   repeat_ancestors_.push_back(node);
64   return Status::OK();
65 }
66 
Visit(std::shared_ptr<GeneratorNode> node,bool * const modified)67 Status GeneratorNodePass::Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) {
68   // Form a reset relationship with the immediate Repeat/EpochCtrl ancestor node of this leaf Generator Node
69   // only when any of its ancestors is an infinite repeat.
70   if (repeat_ancestors_.size() > 0) {
71     bool infinite_repeat = false;
72     for (auto &repeat_ancestor : repeat_ancestors_) {
73       if (repeat_ancestor->Count() < 0) {
74         infinite_repeat = true;
75         break;
76       }
77     }
78     if (infinite_repeat) {
79       // Form a pair-wise relationship between this leaf Generator node and its immediate Repeat/EpochCtrl
80       // ancestor node, and between the next adjacent pairs in the vector. For example,
81       // if we have GeneratorNode -> Repeat1 -> Repeat2 -> EpochCtrl(-1), the pair-wise relationships are:
82       // (GeneratorNode, Repeat1), (Repeat1, Repeat2), and (Repeat2, EpochCtrl)
83       for (auto i = repeat_ancestors_.size() - 1; i > 0; --i) {
84         auto ancestor = repeat_ancestors_[i - 1];
85         RETURN_IF_NOT_OK(repeat_ancestors_[i]->AddResetAncestor(ancestor));
86       }
87       RETURN_IF_NOT_OK(node->AddResetAncestor(repeat_ancestors_.back()));
88     }
89   }
90   return Status::OK();
91 }
92 
VisitAfter(std::shared_ptr<RepeatNode> node,bool * const modified)93 Status GeneratorNodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) {
94   // When we backtrack from the same Repeat node, we pop it out from the list of ancestors.
95   repeat_ancestors_.pop_back();
96   return Status::OK();
97 }
98 
VisitAfter(std::shared_ptr<EpochCtrlNode> node,bool * const modified)99 Status GeneratorNodePass::VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
100   // As EpochCtrl node is a terminal node, the process stops here.
101   // Popping it back out of the reset ancestors is unnecessary.
102   // This function becomes a no-op function and can be deleted completely.
103   return Status::OK();
104 }
105 
106 }  // namespace dataset
107 }  // namespace mindspore
108