1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/opt/post/repeat_pass.h"
18
19 #include <memory>
20
21 #include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h"
22 #include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h"
23 #include "minddata/dataset/engine/ir/datasetops/cache_node.h"
24 #include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h"
25 #include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
26 #include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
27
28 namespace mindspore {
29 namespace dataset {
30
RepeatPass()31 RepeatPass::RepeatPass()
32 : num_repeats_(1), num_epochs_(1), is_merge_(false), is_cached_(false), cache_lookup_(nullptr) {}
33
34 // Identifies the subtree below this node as being in a repeated path of the tree.
Visit(std::shared_ptr<RepeatNode> node,bool * const modified)35 Status RepeatPass::Visit(std::shared_ptr<RepeatNode> node, bool *const modified) {
36 RETURN_UNEXPECTED_IF_NULL(node);
37 RETURN_UNEXPECTED_IF_NULL(modified);
38 // If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_.
39 // Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely.
40 if (node->Count() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) {
41 num_repeats_ = -num_repeats_;
42 }
43 // This RepeatOp and its descendent nodes should be repeated for another num_repeats() times.
44 //
45 // Consider this example:
46 // tfreader --> map --> repeat(2) --> epoch ctrl(3)
47 // num_repeats_ is originally 3, after this repeat(2), num_repeats_ becomes 6 (2*3),
48 // meaning repeat op should be set to read 6 times (2*3), do does map op and tfreader op.
49 //
50 // Another example:
51 // tfreader --> repeat1(3) --> map --> repeat2(2) --> epoch ctrl(4)
52 // num_repeats_ is originally 4, after repeat2(2), num_repeats_ becomes 8 (2*4),
53 // meaning repeat2 and map op should be set to read 8 times (2*4).
54 // Then, after repeat1(3), num_repeats_ becomes 24 (3*2*4), meaning repeat1 and tfreader op should repeat 24 times.
55 num_repeats_ *= node->Count();
56 return Status::OK();
57 }
58
59 // Identifies the subtree below this node as being in a repeated path of the tree.
Visit(std::shared_ptr<EpochCtrlNode> node,bool * const modified)60 Status RepeatPass::Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
61 RETURN_UNEXPECTED_IF_NULL(node);
62 RETURN_UNEXPECTED_IF_NULL(modified);
63 // Get the total number of epochs from the EpochCtrlOp parameter
64 num_epochs_ = node->Count();
65 // Every node below this EpochCtrlOp should be repeated for num_epochs_ times.
66 // For example: tfreader --> epoch ctrl(3)
67 // num_repeats_ is originally 1 (default initialization), after this epoch ctrl(3), num_repeats_ becomes 3 (1*3),
68 // meaning epoch ctrl op should be set to read 3 times (1*3), so does tfreader op.
69 num_repeats_ *= num_epochs_;
70 return Status::OK();
71 }
72
73 #ifndef ENABLE_ANDROID
74 // Identifies the subtree below this node as being in a cache merge path
Visit(std::shared_ptr<CacheMergeNode> node,bool * const modified)75 Status RepeatPass::Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified) {
76 RETURN_UNEXPECTED_IF_NULL(node);
77 RETURN_UNEXPECTED_IF_NULL(modified);
78 // Turn on the flag that we're under a merge op
79 is_merge_ = true;
80 return Status::OK();
81 }
82
83 // Identifies the subtree below this node as being cached
Visit(std::shared_ptr<CacheNode> node,bool * const modified)84 Status RepeatPass::Visit(std::shared_ptr<CacheNode> node, bool *const modified) {
85 RETURN_UNEXPECTED_IF_NULL(node);
86 RETURN_UNEXPECTED_IF_NULL(modified);
87 // Turn on the flag that we're under a merge op
88 is_cached_ = true;
89 return Status::OK();
90 }
91 #endif
92
93 // Hooks up any identified eoe nodes under this repeat.
VisitAfter(std::shared_ptr<RepeatNode> node,bool * const modified)94 Status RepeatPass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) {
95 RETURN_UNEXPECTED_IF_NULL(node);
96 RETURN_UNEXPECTED_IF_NULL(modified);
97 // We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up
98 // and set its total repeats. It is important that the op is removed from the save area,
99 // because the merge op above us may also take action on it later for a different case when
100 // there is no repeat in the merge leg.
101 if (is_merge_ && cache_lookup_) {
102 cache_lookup_->SetTotalRepeats(num_repeats_);
103 cache_lookup_->SetNumEpochs(num_epochs_);
104 cache_lookup_.reset();
105 }
106
107 if (is_cached_) {
108 AddToCachedNodeStack(node);
109 }
110 node->SetTotalRepeats(num_repeats_);
111 node->SetNumEpochs(num_epochs_);
112 // We finish the walk of this RepeatOp's descendent nodes.
113 // The total repeats of nodes above this Repeat(n) have nothing to do with this RepeatOp's parameter n.
114 // But num_repeats_ has been multiplied by n during this Repeat(n)'s PreRunOnNode,
115 // so we divide num_repeats_ by n to be able to correctly set total repeats for nodes above this RepeatOp.
116 CHECK_FAIL_RETURN_UNEXPECTED(node->Count() != 0, "Invalid data, the number of node can't be 0.");
117 num_repeats_ /= node->Count();
118 return Status::OK();
119 }
120
121 // Hooks up any identified eoe nodes under this repeat.
VisitAfter(std::shared_ptr<EpochCtrlNode> node,bool * const modified)122 Status RepeatPass::VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
123 RETURN_UNEXPECTED_IF_NULL(node);
124 RETURN_UNEXPECTED_IF_NULL(modified);
125 CHECK_FAIL_RETURN_UNEXPECTED(node->Count() != 0, "Invalid data, the number of node can't be 0.");
126 node->SetTotalRepeats(num_repeats_);
127 node->SetNumEpochs(num_epochs_);
128 // We finish the walk of this EpochCtrl's descendent nodes.
129 num_repeats_ /= node->Count();
130 return Status::OK();
131 }
132
133 // All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
134 // for use with a controlling repeat above it.
VisitAfter(std::shared_ptr<DatasetNode> node,bool * const modified)135 Status RepeatPass::VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) {
136 RETURN_UNEXPECTED_IF_NULL(node);
137 RETURN_UNEXPECTED_IF_NULL(modified);
138 // If we are under a cache op, then save ourselves to the cached op stack.
139 if (is_cached_) {
140 AddToCachedNodeStack(node);
141 }
142 // Set total repeats and total epochs for the node
143 node->SetTotalRepeats(num_repeats_);
144 node->SetNumEpochs(num_epochs_);
145 return Status::OK();
146 }
147
148 #ifndef ENABLE_ANDROID
149 // CacheOp removes previous leaf ops and replaces them with itself
VisitAfter(std::shared_ptr<CacheNode> node,bool * const modified)150 Status RepeatPass::VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified) {
151 RETURN_UNEXPECTED_IF_NULL(node);
152 RETURN_UNEXPECTED_IF_NULL(modified);
153 is_cached_ = false;
154
155 // if we are a cache within a repeat path of the tree, then adjust the total repeats and total epochs for cached ops.
156 // So that those cached nodes become 1-time use (up to eoe), never repeated. Instead
157 // the repeating behaviours shall be invoked against the cache op.
158 std::shared_ptr<DatasetNode> cached_node = PopFromCachedNodeStack();
159 while (cached_node != nullptr) {
160 int32_t cached_op_total_repeats = cached_node->GetTotalRepeats() / num_repeats_;
161 cached_node->SetTotalRepeats(cached_op_total_repeats);
162 // Cached ops will only be executed on the first epoch, therefore, num_epochs_ = 1
163 cached_node->SetNumEpochs(1);
164 cached_node = PopFromCachedNodeStack();
165 }
166
167 node->SetTotalRepeats(num_repeats_);
168 node->SetNumEpochs(num_epochs_);
169 return Status::OK();
170 }
171
172 // Turns off the tracking for operations under merge op
VisitAfter(std::shared_ptr<CacheMergeNode> node,bool * const modified)173 Status RepeatPass::VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified) {
174 RETURN_UNEXPECTED_IF_NULL(node);
175 RETURN_UNEXPECTED_IF_NULL(modified);
176 // If there was not any repeat in the merge cache miss leg, then the cache_lookup
177 // would not have been consumed yet. In that case, we need to set its total repeats for it.
178 if (cache_lookup_) {
179 cache_lookup_->SetTotalRepeats(num_repeats_);
180 cache_lookup_->SetNumEpochs(num_epochs_);
181 }
182 node->SetTotalRepeats(num_repeats_);
183 node->SetNumEpochs(num_epochs_);
184 cache_lookup_.reset(); // If we are not repeated then the saved lookup is no longer needed or used
185 is_merge_ = false;
186 return Status::OK();
187 }
188
189 // Saves the lookup up in case it needs to be referenced by a repeat
VisitAfter(std::shared_ptr<CacheLookupNode> node,bool * const modified)190 Status RepeatPass::VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified) {
191 RETURN_UNEXPECTED_IF_NULL(node);
192 RETURN_UNEXPECTED_IF_NULL(modified);
193 if (!node->IsLeaf()) {
194 // By definition, the CacheLookup must be a leaf op. Make that clear here.
195 RETURN_STATUS_UNEXPECTED("CacheLookupOp must be a leaf node!");
196 }
197
198 // save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we
199 // may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself
200 // into the pass so that the decision can be made during the processing of the cache miss leg of the merge.
201 // Further, if there's a repeat above the merge but no repeat in the cache miss leg, then the merge op will
202 // add the lookup to the eoe stack
203 cache_lookup_ = std::static_pointer_cast<DatasetNode>(node);
204
205 return Status::OK();
206 }
207 #endif
208
VisitAfter(std::shared_ptr<TransferNode> node,bool * const modified)209 Status RepeatPass::VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) {
210 RETURN_UNEXPECTED_IF_NULL(node);
211 RETURN_UNEXPECTED_IF_NULL(modified);
212 // Set total repeats and total epochs for the TransferNode
213 node->SetTotalRepeats(num_epochs_);
214 node->SetNumEpochs(num_epochs_);
215 return Status::OK();
216 }
217
218 // Adds an operator to the cached operator stack save area
AddToCachedNodeStack(const std::shared_ptr<DatasetNode> & node)219 void RepeatPass::AddToCachedNodeStack(const std::shared_ptr<DatasetNode> &node) {
220 if (node == nullptr) {
221 return;
222 }
223 cached_node_stacks_.push(node);
224 }
225
226 // Pops an operator from the cached operator stack save area
PopFromCachedNodeStack()227 std::shared_ptr<DatasetNode> RepeatPass::PopFromCachedNodeStack() {
228 std::shared_ptr<DatasetNode> top_node = nullptr;
229 if (!cached_node_stacks_.empty()) {
230 top_node = cached_node_stacks_.top();
231 cached_node_stacks_.pop();
232 }
233 return top_node;
234 }
235 } // namespace dataset
236 } // namespace mindspore
237