• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/opt/pass.h"
18 #include "minddata/dataset/engine/ir/datasetops/batch_node.h"
19 #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
20 #ifndef ENABLE_ANDROID
21 #include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h"
22 #endif
23 #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
24 #ifndef ENABLE_ANDROID
25 #include "minddata/dataset/engine/ir/datasetops/cache_node.h"
26 #include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h"
27 #include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h"
28 #endif
29 #include "minddata/dataset/engine/ir/datasetops/concat_node.h"
30 #include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h"
31 #include "minddata/dataset/engine/ir/datasetops/filter_node.h"
32 #include "minddata/dataset/engine/ir/datasetops/map_node.h"
33 #include "minddata/dataset/engine/ir/datasetops/project_node.h"
34 #include "minddata/dataset/engine/ir/datasetops/rename_node.h"
35 #include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
36 #include "minddata/dataset/engine/ir/datasetops/root_node.h"
37 #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
38 #include "minddata/dataset/engine/ir/datasetops/skip_node.h"
39 #ifndef ENABLE_ANDROID
40 #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
41 #endif
42 #ifdef ENABLE_PYTHON
43 #include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
44 #endif
45 #include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
46 #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
47 #ifdef ENABLE_PYTHON
48 #include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h"
49 #endif
50 #include "minddata/dataset/engine/ir/datasetops/take_node.h"
51 #include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
52 #include "minddata/dataset/engine/ir/datasetops/zip_node.h"
53 
54 namespace mindspore {
55 namespace dataset {
56 
57 // Driver method for TreePass
Run(std::shared_ptr<DatasetNode> root_ir,bool * const modified)58 Status IRTreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) {
59   if (root_ir == nullptr || modified == nullptr) {
60     return Status(StatusCode::kMDUnexpectedError, "Null pointer passed to TreePass");
61   }
62   // Initialize modified flag
63   *modified = false;
64   return this->RunOnTree(root_ir, modified);
65 }
66 
67 // Driver method for NodePass
Run(std::shared_ptr<DatasetNode> root_ir,bool * const modified)68 Status IRNodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) {
69   if (root_ir == nullptr || modified == nullptr) {
70     return Status(StatusCode::kMDUnexpectedError, "Null pointer passed to NodePass");
71   }
72   // Initialize modified flag
73   *modified = false;
74   if (traversalOrder_ == Order::DFS) {
75     // DFS
76     return DFSNodeVisit(root_ir, modified);
77   } else if (traversalOrder_ == Order::BFS) {
78     // BFS
79     return BFSNodeVisit(root_ir, modified);
80   }
81   return Status::OK();
82 }
83 
84 // Helper function to perform DFS visit
DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir,bool * const modified)85 Status IRNodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *const modified) {
86   bool m = false;
87 
88   RETURN_IF_NOT_OK(node_ir->Accept(this, &m));
89   *modified = *modified || m;
90   for (const auto &c : node_ir->Children()) {
91     RETURN_IF_NOT_OK(this->DFSNodeVisit(c, &m));
92     *modified = *modified || m;
93   }
94   RETURN_IF_NOT_OK(node_ir->AcceptAfter(this, &m));
95   *modified = *modified || m;
96   return Status::OK();
97 }
98 
99 // Helper function to perform BFS visit
BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir,bool * const modified)100 Status IRNodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *const modified) {
101   bool m = false;
102 
103   // Initialize bfs queue with root
104   std::queue<std::shared_ptr<DatasetNode>> bfsQueue;
105   bfsQueue.push(node_ir);
106 
107   // BFS loop
108   while (!bfsQueue.empty()) {
109     // Pop the front of the bfs queue
110     auto curNode = bfsQueue.front();
111     bfsQueue.pop();
112 
113     // Run node pass
114     RETURN_IF_NOT_OK(curNode->Accept(this, &m));
115     *modified = *modified || m;
116 
117     // Push children into bfs queue
118     for (const auto &c : curNode->Children()) {
119       bfsQueue.push(c);
120     }
121   }
122   return Status::OK();
123 }
124 
125 // For non-leaf IR node
Visit(std::shared_ptr<BatchNode> node,bool * const modified)126 Status IRNodePass::Visit(std::shared_ptr<BatchNode> node, bool *const modified) {
127   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
128 }
VisitAfter(std::shared_ptr<BatchNode> node,bool * const modified)129 Status IRNodePass::VisitAfter(std::shared_ptr<BatchNode> node, bool *const modified) {
130   return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
131 }
Visit(std::shared_ptr<BucketBatchByLengthNode> node,bool * const modified)132 Status IRNodePass::Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *const modified) {
133   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
134 }
VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node,bool * const modified)135 Status IRNodePass::VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *const modified) {
136   return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
137 }
Visit(std::shared_ptr<BuildVocabNode> node,bool * const modified)138 Status IRNodePass::Visit(std::shared_ptr<BuildVocabNode> node, bool *const modified) {
139   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
140 }
VisitAfter(std::shared_ptr<BuildVocabNode> node,bool * const modified)141 Status IRNodePass::VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *const modified) {
142   return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
143 }
Visit(std::shared_ptr<ConcatNode> node,bool * const modified)144 Status IRNodePass::Visit(std::shared_ptr<ConcatNode> node, bool *const modified) {
145   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
146 }
VisitAfter(std::shared_ptr<ConcatNode> node,bool * const modified)147 Status IRNodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *const modified) {
148   return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
149 }
150 #ifndef ENABLE_ANDROID
Visit(std::shared_ptr<CacheLookupNode> node,bool * const modified)151 Status IRNodePass::Visit(std::shared_ptr<CacheLookupNode> node, bool *const modified) {
152   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
153 }
VisitAfter(std::shared_ptr<CacheLookupNode> node,bool * const modified)154 Status IRNodePass::VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified) {
155   return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
156 }
Visit(std::shared_ptr<CacheMergeNode> node,bool * const modified)157 Status IRNodePass::Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified) {
158   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
159 }
VisitAfter(std::shared_ptr<CacheMergeNode> node,bool * const modified)160 Status IRNodePass::VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified) {
161   return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
162 }
Visit(std::shared_ptr<CacheNode> node,bool * const modified)163 Status IRNodePass::Visit(std::shared_ptr<CacheNode> node, bool *const modified) {
164   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
165 }
VisitAfter(std::shared_ptr<CacheNode> node,bool * const modified)166 Status IRNodePass::VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified) {
167   return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
168 }
169 #endif
Visit(std::shared_ptr<EpochCtrlNode> node,bool * const modified)170 Status IRNodePass::Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
171   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
172 }
VisitAfter(std::shared_ptr<EpochCtrlNode> node,bool * const modified)173 Status IRNodePass::VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
174   return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
175 }
Visit(std::shared_ptr<FilterNode> node,bool * const modified)176 Status IRNodePass::Visit(std::shared_ptr<FilterNode> node, bool *const modified) {
177   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
178 }
VisitAfter(std::shared_ptr<FilterNode> node,bool * const modified)179 Status IRNodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *const modified) {
180   return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
181 }
#ifdef ENABLE_PYTHON
// GeneratorNode (Python builds only): default handling upcasts to
// MappableSourceNode and re-dispatches to that overload.
Status IRNodePass::Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) {
  const std::shared_ptr<MappableSourceNode> as_mappable = node;
  return Visit(as_mappable, modified);
}
Status IRNodePass::VisitAfter(std::shared_ptr<GeneratorNode> node, bool *const modified) {
  const std::shared_ptr<MappableSourceNode> as_mappable = node;
  return VisitAfter(as_mappable, modified);
}
#endif
Visit(std::shared_ptr<MapNode> node,bool * const modified)190 Status IRNodePass::Visit(std::shared_ptr<MapNode> node, bool *const modified) {
191   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
192 }
VisitAfter(std::shared_ptr<MapNode> node,bool * const modified)193 Status IRNodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *const modified) {
194   return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
195 }
196 #ifndef ENABLE_ANDROID
Visit(std::shared_ptr<MindDataNode> node,bool * const modified)197 Status IRNodePass::Visit(std::shared_ptr<MindDataNode> node, bool *const modified) {
198   return Visit(std::static_pointer_cast<MappableSourceNode>(node), modified);
199 }
VisitAfter(std::shared_ptr<MindDataNode> node,bool * const modified)200 Status IRNodePass::VisitAfter(std::shared_ptr<MindDataNode> node, bool *const modified) {
201   return VisitAfter(std::static_pointer_cast<MappableSourceNode>(node), modified);
202 }
203 #endif
Visit(std::shared_ptr<ProjectNode> node,bool * const modified)204 Status IRNodePass::Visit(std::shared_ptr<ProjectNode> node, bool *const modified) {
205   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
206 }
VisitAfter(std::shared_ptr<ProjectNode> node,bool * const modified)207 Status IRNodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *const modified) {
208   return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
209 }
Visit(std::shared_ptr<RandomNode> node,bool * const modified)210 Status IRNodePass::Visit(std::shared_ptr<RandomNode> node, bool *const modified) {
211   return Visit(std::static_pointer_cast<NonMappableSourceNode>(node), modified);
212 }
VisitAfter(std::shared_ptr<RandomNode> node,bool * const modified)213 Status IRNodePass::VisitAfter(std::shared_ptr<RandomNode> node, bool *const modified) {
214   return VisitAfter(std::static_pointer_cast<NonMappableSourceNode>(node), modified);
215 }
Visit(std::shared_ptr<RenameNode> node,bool * const modified)216 Status IRNodePass::Visit(std::shared_ptr<RenameNode> node, bool *const modified) {
217   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
218 }
VisitAfter(std::shared_ptr<RenameNode> node,bool * const modified)219 Status IRNodePass::VisitAfter(std::shared_ptr<RenameNode> node, bool *const modified) {
220   return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
221 }
Visit(std::shared_ptr<RepeatNode> node,bool * const modified)222 Status IRNodePass::Visit(std::shared_ptr<RepeatNode> node, bool *const modified) {
223   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
224 }
VisitAfter(std::shared_ptr<RepeatNode> node,bool * const modified)225 Status IRNodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) {
226   return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
227 }
Visit(std::shared_ptr<RootNode> node,bool * const modified)228 Status IRNodePass::Visit(std::shared_ptr<RootNode> node, bool *const modified) {
229   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
230 }
VisitAfter(std::shared_ptr<RootNode> node,bool * const modified)231 Status IRNodePass::VisitAfter(std::shared_ptr<RootNode> node, bool *const modified) {
232   return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
233 }
Visit(std::shared_ptr<ShuffleNode> node,bool * const modified)234 Status IRNodePass::Visit(std::shared_ptr<ShuffleNode> node, bool *const modified) {
235   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
236 }
VisitAfter(std::shared_ptr<ShuffleNode> node,bool * const modified)237 Status IRNodePass::VisitAfter(std::shared_ptr<ShuffleNode> node, bool *const modified) {
238   return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
239 }
Visit(std::shared_ptr<SkipNode> node,bool * const modified)240 Status IRNodePass::Visit(std::shared_ptr<SkipNode> node, bool *const modified) {
241   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
242 }
VisitAfter(std::shared_ptr<SkipNode> node,bool * const modified)243 Status IRNodePass::VisitAfter(std::shared_ptr<SkipNode> node, bool *const modified) {
244   return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
245 }
Visit(std::shared_ptr<TakeNode> node,bool * const modified)246 Status IRNodePass::Visit(std::shared_ptr<TakeNode> node, bool *const modified) {
247   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
248 }
VisitAfter(std::shared_ptr<TakeNode> node,bool * const modified)249 Status IRNodePass::VisitAfter(std::shared_ptr<TakeNode> node, bool *const modified) {
250   return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
251 }
Visit(std::shared_ptr<TFRecordNode> node,bool * const modified)252 Status IRNodePass::Visit(std::shared_ptr<TFRecordNode> node, bool *const modified) {
253   return Visit(std::static_pointer_cast<NonMappableSourceNode>(node), modified);
254 }
VisitAfter(std::shared_ptr<TFRecordNode> node,bool * const modified)255 Status IRNodePass::VisitAfter(std::shared_ptr<TFRecordNode> node, bool *const modified) {
256   return VisitAfter(std::static_pointer_cast<NonMappableSourceNode>(node), modified);
257 }
Visit(std::shared_ptr<TransferNode> node,bool * const modified)258 Status IRNodePass::Visit(std::shared_ptr<TransferNode> node, bool *const modified) {
259   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
260 }
VisitAfter(std::shared_ptr<TransferNode> node,bool * const modified)261 Status IRNodePass::VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) {
262   return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
263 }
Visit(std::shared_ptr<ZipNode> node,bool * const modified)264 Status IRNodePass::Visit(std::shared_ptr<ZipNode> node, bool *const modified) {
265   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
266 }
VisitAfter(std::shared_ptr<ZipNode> node,bool * const modified)267 Status IRNodePass::VisitAfter(std::shared_ptr<ZipNode> node, bool *const modified) {
268   return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
269 }
#ifdef ENABLE_PYTHON
// SyncWaitNode (Python builds only): default handling upcasts to DatasetNode
// and re-dispatches to the generic overload.
Status IRNodePass::Visit(std::shared_ptr<SyncWaitNode> node, bool *const modified) {
  const std::shared_ptr<DatasetNode> base = node;
  return Visit(base, modified);
}
Status IRNodePass::VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *const modified) {
  const std::shared_ptr<DatasetNode> base = node;
  return VisitAfter(base, modified);
}
#endif
278 #ifndef ENABLE_ANDROID
Visit(std::shared_ptr<BuildSentenceVocabNode> node,bool * const modified)279 Status IRNodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *const modified) {
280   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
281 }
VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node,bool * const modified)282 Status IRNodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *const modified) {
283   return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
284 }
285 #endif
286 
287 // leaf-IR Node
Visit(std::shared_ptr<MappableSourceNode> node,bool * const modified)288 Status IRNodePass::Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) {
289   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
290 }
291 
Visit(std::shared_ptr<NonMappableSourceNode> node,bool * const modified)292 Status IRNodePass::Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) {
293   return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
294 }
295 }  // namespace dataset
296 }  // namespace mindspore
297