• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/opt/post/repeat_pass.h"
18 
19 #include <memory>
20 
21 #include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h"
22 #include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h"
23 #include "minddata/dataset/engine/ir/datasetops/cache_node.h"
24 #include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h"
25 #include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
26 #include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
27 
28 namespace mindspore {
29 namespace dataset {
30 
RepeatPass()31 RepeatPass::RepeatPass()
32     : num_repeats_(1), num_epochs_(1), is_merge_(false), is_cached_(false), cache_lookup_(nullptr) {}
33 
34 // Identifies the subtree below this node as being in a repeated path of the tree.
Visit(std::shared_ptr<RepeatNode> node,bool * const modified)35 Status RepeatPass::Visit(std::shared_ptr<RepeatNode> node, bool *const modified) {
36   RETURN_UNEXPECTED_IF_NULL(node);
37   RETURN_UNEXPECTED_IF_NULL(modified);
38   // If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_.
39   // Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely.
40   if (node->Count() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) {
41     num_repeats_ = -num_repeats_;
42   }
43   // This RepeatOp and its descendent nodes should be repeated for another num_repeats() times.
44   //
45   // Consider this example:
46   // tfreader --> map --> repeat(2) --> epoch ctrl(3)
47   // num_repeats_ is originally 3, after this repeat(2), num_repeats_ becomes 6 (2*3),
48   // meaning repeat op should be set to read 6 times (2*3), do does map op and tfreader op.
49   //
50   // Another example:
51   // tfreader --> repeat1(3) --> map --> repeat2(2) --> epoch ctrl(4)
52   // num_repeats_ is originally 4, after repeat2(2), num_repeats_ becomes 8 (2*4),
53   // meaning repeat2 and map op should be set to read 8 times (2*4).
54   // Then, after repeat1(3), num_repeats_ becomes 24 (3*2*4), meaning repeat1 and tfreader op should repeat 24 times.
55   num_repeats_ *= node->Count();
56   return Status::OK();
57 }
58 
59 // Identifies the subtree below this node as being in a repeated path of the tree.
Visit(std::shared_ptr<EpochCtrlNode> node,bool * const modified)60 Status RepeatPass::Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
61   RETURN_UNEXPECTED_IF_NULL(node);
62   RETURN_UNEXPECTED_IF_NULL(modified);
63   // Get the total number of epochs from the EpochCtrlOp parameter
64   num_epochs_ = node->Count();
65   // Every node below this EpochCtrlOp should be repeated for num_epochs_ times.
66   // For example: tfreader --> epoch ctrl(3)
67   // num_repeats_ is originally 1 (default initialization), after this epoch ctrl(3), num_repeats_ becomes 3 (1*3),
68   // meaning epoch ctrl op should be set to read 3 times (1*3), so does tfreader op.
69   num_repeats_ *= num_epochs_;
70   return Status::OK();
71 }
72 
73 #ifndef ENABLE_ANDROID
74 // Identifies the subtree below this node as being in a cache merge path
Visit(std::shared_ptr<CacheMergeNode> node,bool * const modified)75 Status RepeatPass::Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified) {
76   RETURN_UNEXPECTED_IF_NULL(node);
77   RETURN_UNEXPECTED_IF_NULL(modified);
78   // Turn on the flag that we're under a merge op
79   is_merge_ = true;
80   return Status::OK();
81 }
82 
83 // Identifies the subtree below this node as being cached
Visit(std::shared_ptr<CacheNode> node,bool * const modified)84 Status RepeatPass::Visit(std::shared_ptr<CacheNode> node, bool *const modified) {
85   RETURN_UNEXPECTED_IF_NULL(node);
86   RETURN_UNEXPECTED_IF_NULL(modified);
87   // Turn on the flag that we're under a merge op
88   is_cached_ = true;
89   return Status::OK();
90 }
91 #endif
92 
93 // Hooks up any identified eoe nodes under this repeat.
VisitAfter(std::shared_ptr<RepeatNode> node,bool * const modified)94 Status RepeatPass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) {
95   RETURN_UNEXPECTED_IF_NULL(node);
96   RETURN_UNEXPECTED_IF_NULL(modified);
97   // We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up
98   // and set its total repeats. It is important that the op is removed from the save area,
99   // because the merge op above us may also take action on it later for a different case when
100   // there is no repeat in the merge leg.
101   if (is_merge_ && cache_lookup_) {
102     cache_lookup_->SetTotalRepeats(num_repeats_);
103     cache_lookup_->SetNumEpochs(num_epochs_);
104     cache_lookup_.reset();
105   }
106 
107   if (is_cached_) {
108     AddToCachedNodeStack(node);
109   }
110   node->SetTotalRepeats(num_repeats_);
111   node->SetNumEpochs(num_epochs_);
112   // We finish the walk of this RepeatOp's descendent nodes.
113   // The total repeats of nodes above this Repeat(n) have nothing to do with this RepeatOp's parameter n.
114   // But num_repeats_ has been multiplied by n during this Repeat(n)'s PreRunOnNode,
115   // so we divide num_repeats_ by n to be able to correctly set total repeats for nodes above this RepeatOp.
116   CHECK_FAIL_RETURN_UNEXPECTED(node->Count() != 0, "Invalid data, the number of node can't be 0.");
117   num_repeats_ /= node->Count();
118   return Status::OK();
119 }
120 
121 // Hooks up any identified eoe nodes under this repeat.
VisitAfter(std::shared_ptr<EpochCtrlNode> node,bool * const modified)122 Status RepeatPass::VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
123   RETURN_UNEXPECTED_IF_NULL(node);
124   RETURN_UNEXPECTED_IF_NULL(modified);
125   CHECK_FAIL_RETURN_UNEXPECTED(node->Count() != 0, "Invalid data, the number of node can't be 0.");
126   node->SetTotalRepeats(num_repeats_);
127   node->SetNumEpochs(num_epochs_);
128   // We finish the walk of this EpochCtrl's descendent nodes.
129   num_repeats_ /= node->Count();
130   return Status::OK();
131 }
132 
133 // All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
134 // for use with a controlling repeat above it.
VisitAfter(std::shared_ptr<DatasetNode> node,bool * const modified)135 Status RepeatPass::VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) {
136   RETURN_UNEXPECTED_IF_NULL(node);
137   RETURN_UNEXPECTED_IF_NULL(modified);
138   // If we are under a cache op, then save ourselves to the cached op stack.
139   if (is_cached_) {
140     AddToCachedNodeStack(node);
141   }
142   // Set total repeats and total epochs for the node
143   node->SetTotalRepeats(num_repeats_);
144   node->SetNumEpochs(num_epochs_);
145   return Status::OK();
146 }
147 
148 #ifndef ENABLE_ANDROID
149 // CacheOp removes previous leaf ops and replaces them with itself
VisitAfter(std::shared_ptr<CacheNode> node,bool * const modified)150 Status RepeatPass::VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified) {
151   RETURN_UNEXPECTED_IF_NULL(node);
152   RETURN_UNEXPECTED_IF_NULL(modified);
153   is_cached_ = false;
154 
155   // if we are a cache within a repeat path of the tree, then adjust the total repeats and total epochs for cached ops.
156   // So that those cached nodes become 1-time use (up to eoe), never repeated.  Instead
157   // the repeating behaviours shall be invoked against the cache op.
158   std::shared_ptr<DatasetNode> cached_node = PopFromCachedNodeStack();
159   while (cached_node != nullptr) {
160     int32_t cached_op_total_repeats = cached_node->GetTotalRepeats() / num_repeats_;
161     cached_node->SetTotalRepeats(cached_op_total_repeats);
162     // Cached ops will only be executed on the first epoch, therefore, num_epochs_ = 1
163     cached_node->SetNumEpochs(1);
164     cached_node = PopFromCachedNodeStack();
165   }
166 
167   node->SetTotalRepeats(num_repeats_);
168   node->SetNumEpochs(num_epochs_);
169   return Status::OK();
170 }
171 
172 // Turns off the tracking for operations under merge op
VisitAfter(std::shared_ptr<CacheMergeNode> node,bool * const modified)173 Status RepeatPass::VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified) {
174   RETURN_UNEXPECTED_IF_NULL(node);
175   RETURN_UNEXPECTED_IF_NULL(modified);
176   // If there was not any repeat in the merge cache miss leg, then the cache_lookup
177   // would not have been consumed yet.  In that case, we need to set its total repeats for it.
178   if (cache_lookup_) {
179     cache_lookup_->SetTotalRepeats(num_repeats_);
180     cache_lookup_->SetNumEpochs(num_epochs_);
181   }
182   node->SetTotalRepeats(num_repeats_);
183   node->SetNumEpochs(num_epochs_);
184   cache_lookup_.reset();  // If we are not repeated then the saved lookup is no longer needed or used
185   is_merge_ = false;
186   return Status::OK();
187 }
188 
189 // Saves the lookup up in case it needs to be referenced by a repeat
VisitAfter(std::shared_ptr<CacheLookupNode> node,bool * const modified)190 Status RepeatPass::VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified) {
191   RETURN_UNEXPECTED_IF_NULL(node);
192   RETURN_UNEXPECTED_IF_NULL(modified);
193   if (!node->IsLeaf()) {
194     // By definition, the CacheLookup must be a leaf op.  Make that clear here.
195     RETURN_STATUS_UNEXPECTED("CacheLookupOp must be a leaf node!");
196   }
197 
198   // save the lookup op.  There could be a repeat in the cache miss leg of the merge op, in which case we
199   // may still need to be flagged as a repeating leaf.  We can't decide that here though, so save ourself
200   // into the pass so that the decision can be made during the processing of the cache miss leg of the merge.
201   // Further, if there's a repeat above the merge but no repeat in the cache miss leg, then the merge op will
202   // add the lookup to the eoe stack
203   cache_lookup_ = std::static_pointer_cast<DatasetNode>(node);
204 
205   return Status::OK();
206 }
207 #endif
208 
VisitAfter(std::shared_ptr<TransferNode> node,bool * const modified)209 Status RepeatPass::VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) {
210   RETURN_UNEXPECTED_IF_NULL(node);
211   RETURN_UNEXPECTED_IF_NULL(modified);
212   // Set total repeats and total epochs for the TransferNode
213   node->SetTotalRepeats(num_epochs_);
214   node->SetNumEpochs(num_epochs_);
215   return Status::OK();
216 }
217 
218 // Adds an operator to the cached operator stack save area
AddToCachedNodeStack(const std::shared_ptr<DatasetNode> & node)219 void RepeatPass::AddToCachedNodeStack(const std::shared_ptr<DatasetNode> &node) {
220   if (node == nullptr) {
221     return;
222   }
223   cached_node_stacks_.push(node);
224 }
225 
226 // Pops an operator from the cached operator stack save area
PopFromCachedNodeStack()227 std::shared_ptr<DatasetNode> RepeatPass::PopFromCachedNodeStack() {
228   std::shared_ptr<DatasetNode> top_node = nullptr;
229   if (!cached_node_stacks_.empty()) {
230     top_node = cached_node_stacks_.top();
231     cached_node_stacks_.pop();
232   }
233   return top_node;
234 }
235 }  // namespace dataset
236 }  // namespace mindspore
237