1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/opt/pass.h"
18 #include "minddata/dataset/engine/ir/datasetops/batch_node.h"
19 #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
20 #ifndef ENABLE_ANDROID
21 #include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h"
22 #endif
23 #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
24 #ifndef ENABLE_ANDROID
25 #include "minddata/dataset/engine/ir/datasetops/cache_node.h"
26 #include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h"
27 #include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h"
28 #endif
29 #include "minddata/dataset/engine/ir/datasetops/concat_node.h"
30 #include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h"
31 #include "minddata/dataset/engine/ir/datasetops/filter_node.h"
32 #include "minddata/dataset/engine/ir/datasetops/map_node.h"
33 #include "minddata/dataset/engine/ir/datasetops/project_node.h"
34 #include "minddata/dataset/engine/ir/datasetops/rename_node.h"
35 #include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
36 #include "minddata/dataset/engine/ir/datasetops/root_node.h"
37 #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
38 #include "minddata/dataset/engine/ir/datasetops/skip_node.h"
39 #ifndef ENABLE_ANDROID
40 #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
41 #endif
42 #ifdef ENABLE_PYTHON
43 #include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
44 #endif
45 #include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
46 #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
47 #ifdef ENABLE_PYTHON
48 #include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h"
49 #endif
50 #include "minddata/dataset/engine/ir/datasetops/take_node.h"
51 #include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
52 #include "minddata/dataset/engine/ir/datasetops/zip_node.h"
53
54 namespace mindspore {
55 namespace dataset {
56
57 // Driver method for TreePass
Run(std::shared_ptr<DatasetNode> root_ir,bool * const modified)58 Status IRTreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) {
59 if (root_ir == nullptr || modified == nullptr) {
60 return Status(StatusCode::kMDUnexpectedError, "Null pointer passed to TreePass");
61 }
62 // Initialize modified flag
63 *modified = false;
64 return this->RunOnTree(root_ir, modified);
65 }
66
67 // Driver method for NodePass
Run(std::shared_ptr<DatasetNode> root_ir,bool * const modified)68 Status IRNodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) {
69 if (root_ir == nullptr || modified == nullptr) {
70 return Status(StatusCode::kMDUnexpectedError, "Null pointer passed to NodePass");
71 }
72 // Initialize modified flag
73 *modified = false;
74 if (traversalOrder_ == Order::DFS) {
75 // DFS
76 return DFSNodeVisit(root_ir, modified);
77 } else if (traversalOrder_ == Order::BFS) {
78 // BFS
79 return BFSNodeVisit(root_ir, modified);
80 }
81 return Status::OK();
82 }
83
84 // Helper function to perform DFS visit
DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir,bool * const modified)85 Status IRNodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *const modified) {
86 bool m = false;
87
88 RETURN_IF_NOT_OK(node_ir->Accept(this, &m));
89 *modified = *modified || m;
90 for (const auto &c : node_ir->Children()) {
91 RETURN_IF_NOT_OK(this->DFSNodeVisit(c, &m));
92 *modified = *modified || m;
93 }
94 RETURN_IF_NOT_OK(node_ir->AcceptAfter(this, &m));
95 *modified = *modified || m;
96 return Status::OK();
97 }
98
99 // Helper function to perform BFS visit
BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir,bool * const modified)100 Status IRNodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *const modified) {
101 bool m = false;
102
103 // Initialize bfs queue with root
104 std::queue<std::shared_ptr<DatasetNode>> bfsQueue;
105 bfsQueue.push(node_ir);
106
107 // BFS loop
108 while (!bfsQueue.empty()) {
109 // Pop the front of the bfs queue
110 auto curNode = bfsQueue.front();
111 bfsQueue.pop();
112
113 // Run node pass
114 RETURN_IF_NOT_OK(curNode->Accept(this, &m));
115 *modified = *modified || m;
116
117 // Push children into bfs queue
118 for (const auto &c : curNode->Children()) {
119 bfsQueue.push(c);
120 }
121 }
122 return Status::OK();
123 }
124
125 // For non-leaf IR node
Visit(std::shared_ptr<BatchNode> node,bool * const modified)126 Status IRNodePass::Visit(std::shared_ptr<BatchNode> node, bool *const modified) {
127 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
128 }
VisitAfter(std::shared_ptr<BatchNode> node,bool * const modified)129 Status IRNodePass::VisitAfter(std::shared_ptr<BatchNode> node, bool *const modified) {
130 return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
131 }
Visit(std::shared_ptr<BucketBatchByLengthNode> node,bool * const modified)132 Status IRNodePass::Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *const modified) {
133 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
134 }
VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node,bool * const modified)135 Status IRNodePass::VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *const modified) {
136 return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
137 }
Visit(std::shared_ptr<BuildVocabNode> node,bool * const modified)138 Status IRNodePass::Visit(std::shared_ptr<BuildVocabNode> node, bool *const modified) {
139 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
140 }
VisitAfter(std::shared_ptr<BuildVocabNode> node,bool * const modified)141 Status IRNodePass::VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *const modified) {
142 return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
143 }
Visit(std::shared_ptr<ConcatNode> node,bool * const modified)144 Status IRNodePass::Visit(std::shared_ptr<ConcatNode> node, bool *const modified) {
145 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
146 }
VisitAfter(std::shared_ptr<ConcatNode> node,bool * const modified)147 Status IRNodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *const modified) {
148 return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
149 }
150 #ifndef ENABLE_ANDROID
Visit(std::shared_ptr<CacheLookupNode> node,bool * const modified)151 Status IRNodePass::Visit(std::shared_ptr<CacheLookupNode> node, bool *const modified) {
152 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
153 }
VisitAfter(std::shared_ptr<CacheLookupNode> node,bool * const modified)154 Status IRNodePass::VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified) {
155 return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
156 }
Visit(std::shared_ptr<CacheMergeNode> node,bool * const modified)157 Status IRNodePass::Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified) {
158 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
159 }
VisitAfter(std::shared_ptr<CacheMergeNode> node,bool * const modified)160 Status IRNodePass::VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified) {
161 return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
162 }
Visit(std::shared_ptr<CacheNode> node,bool * const modified)163 Status IRNodePass::Visit(std::shared_ptr<CacheNode> node, bool *const modified) {
164 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
165 }
VisitAfter(std::shared_ptr<CacheNode> node,bool * const modified)166 Status IRNodePass::VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified) {
167 return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
168 }
169 #endif
Visit(std::shared_ptr<EpochCtrlNode> node,bool * const modified)170 Status IRNodePass::Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
171 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
172 }
VisitAfter(std::shared_ptr<EpochCtrlNode> node,bool * const modified)173 Status IRNodePass::VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
174 return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
175 }
Visit(std::shared_ptr<FilterNode> node,bool * const modified)176 Status IRNodePass::Visit(std::shared_ptr<FilterNode> node, bool *const modified) {
177 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
178 }
VisitAfter(std::shared_ptr<FilterNode> node,bool * const modified)179 Status IRNodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *const modified) {
180 return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
181 }
#ifdef ENABLE_PYTHON
// Default GeneratorNode hooks (Python builds only). The node is upcast to
// MappableSourceNode, and the call goes to whichever overload is selected
// for that type.
Status IRNodePass::Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) {
  std::shared_ptr<MappableSourceNode> mappable_node(node);
  return Visit(mappable_node, modified);
}
Status IRNodePass::VisitAfter(std::shared_ptr<GeneratorNode> node, bool *const modified) {
  std::shared_ptr<MappableSourceNode> mappable_node(node);
  return VisitAfter(mappable_node, modified);
}
#endif
Visit(std::shared_ptr<MapNode> node,bool * const modified)190 Status IRNodePass::Visit(std::shared_ptr<MapNode> node, bool *const modified) {
191 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
192 }
VisitAfter(std::shared_ptr<MapNode> node,bool * const modified)193 Status IRNodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *const modified) {
194 return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
195 }
196 #ifndef ENABLE_ANDROID
Visit(std::shared_ptr<MindDataNode> node,bool * const modified)197 Status IRNodePass::Visit(std::shared_ptr<MindDataNode> node, bool *const modified) {
198 return Visit(std::static_pointer_cast<MappableSourceNode>(node), modified);
199 }
VisitAfter(std::shared_ptr<MindDataNode> node,bool * const modified)200 Status IRNodePass::VisitAfter(std::shared_ptr<MindDataNode> node, bool *const modified) {
201 return VisitAfter(std::static_pointer_cast<MappableSourceNode>(node), modified);
202 }
203 #endif
Visit(std::shared_ptr<ProjectNode> node,bool * const modified)204 Status IRNodePass::Visit(std::shared_ptr<ProjectNode> node, bool *const modified) {
205 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
206 }
VisitAfter(std::shared_ptr<ProjectNode> node,bool * const modified)207 Status IRNodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *const modified) {
208 return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
209 }
Visit(std::shared_ptr<RandomNode> node,bool * const modified)210 Status IRNodePass::Visit(std::shared_ptr<RandomNode> node, bool *const modified) {
211 return Visit(std::static_pointer_cast<NonMappableSourceNode>(node), modified);
212 }
VisitAfter(std::shared_ptr<RandomNode> node,bool * const modified)213 Status IRNodePass::VisitAfter(std::shared_ptr<RandomNode> node, bool *const modified) {
214 return VisitAfter(std::static_pointer_cast<NonMappableSourceNode>(node), modified);
215 }
Visit(std::shared_ptr<RenameNode> node,bool * const modified)216 Status IRNodePass::Visit(std::shared_ptr<RenameNode> node, bool *const modified) {
217 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
218 }
VisitAfter(std::shared_ptr<RenameNode> node,bool * const modified)219 Status IRNodePass::VisitAfter(std::shared_ptr<RenameNode> node, bool *const modified) {
220 return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
221 }
Visit(std::shared_ptr<RepeatNode> node,bool * const modified)222 Status IRNodePass::Visit(std::shared_ptr<RepeatNode> node, bool *const modified) {
223 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
224 }
VisitAfter(std::shared_ptr<RepeatNode> node,bool * const modified)225 Status IRNodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) {
226 return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
227 }
Visit(std::shared_ptr<RootNode> node,bool * const modified)228 Status IRNodePass::Visit(std::shared_ptr<RootNode> node, bool *const modified) {
229 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
230 }
VisitAfter(std::shared_ptr<RootNode> node,bool * const modified)231 Status IRNodePass::VisitAfter(std::shared_ptr<RootNode> node, bool *const modified) {
232 return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
233 }
Visit(std::shared_ptr<ShuffleNode> node,bool * const modified)234 Status IRNodePass::Visit(std::shared_ptr<ShuffleNode> node, bool *const modified) {
235 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
236 }
VisitAfter(std::shared_ptr<ShuffleNode> node,bool * const modified)237 Status IRNodePass::VisitAfter(std::shared_ptr<ShuffleNode> node, bool *const modified) {
238 return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
239 }
Visit(std::shared_ptr<SkipNode> node,bool * const modified)240 Status IRNodePass::Visit(std::shared_ptr<SkipNode> node, bool *const modified) {
241 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
242 }
VisitAfter(std::shared_ptr<SkipNode> node,bool * const modified)243 Status IRNodePass::VisitAfter(std::shared_ptr<SkipNode> node, bool *const modified) {
244 return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
245 }
Visit(std::shared_ptr<TakeNode> node,bool * const modified)246 Status IRNodePass::Visit(std::shared_ptr<TakeNode> node, bool *const modified) {
247 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
248 }
VisitAfter(std::shared_ptr<TakeNode> node,bool * const modified)249 Status IRNodePass::VisitAfter(std::shared_ptr<TakeNode> node, bool *const modified) {
250 return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
251 }
Visit(std::shared_ptr<TFRecordNode> node,bool * const modified)252 Status IRNodePass::Visit(std::shared_ptr<TFRecordNode> node, bool *const modified) {
253 return Visit(std::static_pointer_cast<NonMappableSourceNode>(node), modified);
254 }
VisitAfter(std::shared_ptr<TFRecordNode> node,bool * const modified)255 Status IRNodePass::VisitAfter(std::shared_ptr<TFRecordNode> node, bool *const modified) {
256 return VisitAfter(std::static_pointer_cast<NonMappableSourceNode>(node), modified);
257 }
Visit(std::shared_ptr<TransferNode> node,bool * const modified)258 Status IRNodePass::Visit(std::shared_ptr<TransferNode> node, bool *const modified) {
259 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
260 }
VisitAfter(std::shared_ptr<TransferNode> node,bool * const modified)261 Status IRNodePass::VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) {
262 return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
263 }
Visit(std::shared_ptr<ZipNode> node,bool * const modified)264 Status IRNodePass::Visit(std::shared_ptr<ZipNode> node, bool *const modified) {
265 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
266 }
VisitAfter(std::shared_ptr<ZipNode> node,bool * const modified)267 Status IRNodePass::VisitAfter(std::shared_ptr<ZipNode> node, bool *const modified) {
268 return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
269 }
#ifdef ENABLE_PYTHON
// Default SyncWaitNode hooks (Python builds only): both upcast the node and
// defer to the generic DatasetNode overloads.
Status IRNodePass::Visit(std::shared_ptr<SyncWaitNode> node, bool *const modified) {
  std::shared_ptr<DatasetNode> base_node(node);
  return Visit(base_node, modified);
}
Status IRNodePass::VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *const modified) {
  std::shared_ptr<DatasetNode> base_node(node);
  return VisitAfter(base_node, modified);
}
#endif
278 #ifndef ENABLE_ANDROID
Visit(std::shared_ptr<BuildSentenceVocabNode> node,bool * const modified)279 Status IRNodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *const modified) {
280 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
281 }
VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node,bool * const modified)282 Status IRNodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *const modified) {
283 return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
284 }
285 #endif
286
287 // leaf-IR Node
Visit(std::shared_ptr<MappableSourceNode> node,bool * const modified)288 Status IRNodePass::Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) {
289 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
290 }
291
Visit(std::shared_ptr<NonMappableSourceNode> node,bool * const modified)292 Status IRNodePass::Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) {
293 return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
294 }
295 } // namespace dataset
296 } // namespace mindspore
297