• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_H_
19 
20 #include <memory>
21 #include <queue>
22 
23 #include "minddata/dataset/engine/execution_tree.h"
24 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
25 #include "minddata/dataset/util/status.h"
26 
27 namespace mindspore {
28 namespace dataset {
// Forward declarations of the IR node classes used by the pass interfaces below.
// Only the names are needed here; the headers of the individual node classes are not included.
//
// Non-leaf IR node
class BatchNode;
class BucketBatchByLengthNode;
class BuildVocabNode;
#ifndef ENABLE_ANDROID
class BuildSentenceVocabNode;
class CacheLookupNode;
class CacheMergeNode;
class CacheNode;
#endif
class ConcatNode;
class EpochCtrlNode;
class FilterNode;
class MapNode;
class ProjectNode;
class RenameNode;
class RepeatNode;
class RootNode;
class ShuffleNode;
class SkipNode;
class TakeNode;
class TransferNode;
class ZipNode;
#ifdef ENABLE_PYTHON
class SyncWaitNode;
#endif
// Leaf IR node
class AlbumNode;
class CelebANode;
class Cifar100Node;
class Cifar10Node;
class CocoNode;
class ImageFolderNode;
class ManifestNode;
class MnistNode;
class RandomNode;
// TFRecordNode is declared exactly once and outside of any ENABLE_ANDROID guard, because
// IRNodePass declares Visit()/VisitAfter() overloads for it unconditionally.
class TFRecordNode;
class VOCNode;
#ifdef ENABLE_PYTHON
class GeneratorNode;
#endif
#ifndef ENABLE_ANDROID
class CLUENode;
class CSVNode;
class MindDataNode;
class TextFileNode;
#endif
79 
80 // The base class Pass is the basic unit of tree transformation.
81 // The actual implementation of the passes will be derived from here.
82 class IRPass : public std::enable_shared_from_this<IRPass> {
83  public:
84   // Run the transformation pass against the IR tree.
85   // @param root_ir - Pointer to the IR tree to be transformed.
86   // @param modified - Pointer to the modified flag,
87   virtual Status Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) = 0;
88 
89   virtual ~IRPass() = default;
90 };
91 
92 // IRTreePass is a basic Pass class which performs transformation on IR tree directly.
93 class IRTreePass : public IRPass {
94  public:
95   /// \brief Run the transformation pass against the IR tree.
96   /// \param[in,out] root_ir Pointer to the IR tree to be transformed.
97   /// \param[in,out] modified Indicate if the tree was modified
98   Status Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) final;
99 
100   /// \brief Derived classes may implement the runOnTree function to implement tree transformation.
101   ///     "modified" flag needs to be set to true if tree is modified during the pass execution.
102   /// \param[in,out] tree The tree to operate on.
103   /// \param[in,out] Indicate if the tree was modified.
104   /// \return Status The status code returned
RunOnTree(std::shared_ptr<DatasetNode> root_ir,bool * const modified)105   virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) { return Status::OK(); }
106 };
107 
108 // IRNodePass is a base Pass class which performs transformation on node visiting.
109 // IRNodePass implements Visitor design pattern.
110 // The visiting happens twice for each node in the DFS traversal, one on the way down of the traversal,
111 // and the other when all the descending nodes are visited.
112 // Actual transformation is done by implementing a new derived class of IRNodePass.
113 // The derived class will implement the method Visit()/VisitAfter() passing specified node types
114 // it wants to action on them, overriding the ones defined in IRNodePass.
115 // If the derived class wants to perform the same action on all node types,
116 // it can simply implement the method Visit()/VisitAfter() passing the base class DatasetNode.
117 // This is made possible by overloading the method Visit()/VisitAfter() on each node type to fall back
118 // to call the Visit()/VisitAfter() in this parent IRNodePass class.
119 class IRNodePass : public IRPass {
120  public:
121   // Tree traversal order
122   enum Order { DFS, BFS };
123 
124   // Constructor
125   // Default DFS traversal
126   explicit IRNodePass(Order order = Order::DFS) { traversalOrder_ = order; }
127 
128   ~IRNodePass() = default;
129 
130   /// \brief Run the transformation pass against the IR tree
131   /// \param[in,out] root_ir Pointer to the IR tree to be transformed
132   /// \param[in,out] modified Indicator if the tree was changed
133   Status Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) final;
134 
135   /// \brief Derived classes may implement the Visit function to implement any initial visit work on the way down
136   ///     a tree traversal.  "modified" flag needs to be set to true if node is modified during the pass execution
137   /// \param[in] node The node being visited
138   /// \param[out] modified Indicator if the node was changed at all
139   /// \return Status The status code returned
Visit(std::shared_ptr<DatasetNode> node,bool * const modified)140   virtual Status Visit(std::shared_ptr<DatasetNode> node, bool *const modified) { return Status::OK(); }
141 
142   /// \brief Derived classes may implement the VisitAfter function to implement node level tree transformation
143   ///     "modified" flag needs to be set to true if node is modified during the pass execution
144   /// \param[in] node The node being visited
145   /// \param[out] modified Indicator if the node was changed at all.
146   /// \return Status The status code returned
VisitAfter(std::shared_ptr<DatasetNode> node,bool * const modified)147   virtual Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) { return Status::OK(); }
148 
149   // Visit()/VisitAfter() method to be overridden.
150   // These pairs of Visit()/VisitAfter() for each derived class of DatasetNode are defined here.
151   // Their implementation are in .cc file to avoid adding the include files of those derived classes.
152   // The implementation simply falls back to call Visit()/VisitAfter of class DatasetNode, the parent of
153   // the derived classes. With this technique, the transformation classes derived from NodePass needs only to
154   // implement Visit()/VisitAfter() passing DatasetNode if it wants to action on any derived classes
155   // of DatasetNode in the same way.
156   // Note that virtual template functions are not permitted in C++.
157   //
158   // Non-leaf IR node
159   virtual Status Visit(std::shared_ptr<BatchNode> node, bool *const modified);
160   virtual Status VisitAfter(std::shared_ptr<BatchNode> node, bool *const modified);
161   virtual Status Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *const modified);
162   virtual Status VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *const modified);
163 #ifndef ENABLE_ANDROID
164   virtual Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *const modified);
165   virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *const modified);
166 #endif
167   virtual Status Visit(std::shared_ptr<BuildVocabNode> node, bool *const modified);
168   virtual Status VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *const modified);
169   virtual Status Visit(std::shared_ptr<ConcatNode> node, bool *const modified);
170   virtual Status VisitAfter(std::shared_ptr<ConcatNode> node, bool *const modified);
171 #ifndef ENABLE_ANDROID
172   virtual Status Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified);
173   virtual Status VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified);
174   virtual Status Visit(std::shared_ptr<CacheLookupNode> node, bool *const modified);
175   virtual Status VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified);
176   virtual Status Visit(std::shared_ptr<CacheNode> node, bool *const modified);
177   virtual Status VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified);
178 #endif
179   virtual Status Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified);
180   virtual Status VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified);
181   virtual Status Visit(std::shared_ptr<FilterNode> node, bool *const modified);
182   virtual Status VisitAfter(std::shared_ptr<FilterNode> node, bool *const modified);
183 #ifdef ENABLE_PYTHON
184   virtual Status Visit(std::shared_ptr<GeneratorNode> node, bool *const modified);
185   virtual Status VisitAfter(std::shared_ptr<GeneratorNode> node, bool *const modified);
186 #endif
187   virtual Status Visit(std::shared_ptr<MapNode> node, bool *const modified);
188   virtual Status VisitAfter(std::shared_ptr<MapNode> node, bool *const modified);
189 #ifndef ENABLE_ANDROID
190   virtual Status Visit(std::shared_ptr<MindDataNode> node, bool *const modified);
191   virtual Status VisitAfter(std::shared_ptr<MindDataNode> node, bool *const modified);
192 #endif
193   virtual Status Visit(std::shared_ptr<ProjectNode> node, bool *const modified);
194   virtual Status VisitAfter(std::shared_ptr<ProjectNode> node, bool *const modified);
195   virtual Status Visit(std::shared_ptr<RandomNode> node, bool *const modified);
196   virtual Status VisitAfter(std::shared_ptr<RandomNode> node, bool *const modified);
197   virtual Status Visit(std::shared_ptr<RenameNode> node, bool *const modified);
198   virtual Status VisitAfter(std::shared_ptr<RenameNode> node, bool *const modified);
199   virtual Status Visit(std::shared_ptr<RepeatNode> node, bool *const modified);
200   virtual Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified);
201   virtual Status Visit(std::shared_ptr<RootNode> node, bool *const modified);
202   virtual Status VisitAfter(std::shared_ptr<RootNode> node, bool *const modified);
203   virtual Status Visit(std::shared_ptr<ShuffleNode> node, bool *const modified);
204   virtual Status VisitAfter(std::shared_ptr<ShuffleNode> node, bool *const modified);
205   virtual Status Visit(std::shared_ptr<SkipNode> node, bool *const modified);
206   virtual Status VisitAfter(std::shared_ptr<SkipNode> node, bool *const modified);
207 #ifdef ENABLE_PYTHON
208   virtual Status Visit(std::shared_ptr<SyncWaitNode> node, bool *const modified);
209   virtual Status VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *const modified);
210 #endif
211   virtual Status Visit(std::shared_ptr<TakeNode> node, bool *const modified);
212   virtual Status VisitAfter(std::shared_ptr<TakeNode> node, bool *const modified);
213   virtual Status Visit(std::shared_ptr<TFRecordNode> node, bool *const modified);
214   virtual Status VisitAfter(std::shared_ptr<TFRecordNode> node, bool *const modified);
215   virtual Status Visit(std::shared_ptr<TransferNode> node, bool *const modified);
216   virtual Status VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified);
217   virtual Status Visit(std::shared_ptr<ZipNode> node, bool *const modified);
218   virtual Status VisitAfter(std::shared_ptr<ZipNode> node, bool *const modified);
219 
220   // leaf-IR Node
221   virtual Status Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified);
222   virtual Status Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified);
223 
224  private:
225   // Helper function to perform DFS visit
226   Status DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *const modified);
227 
228   // Helper function to perform BFS visit
229   Status BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *const modified);
230 
231   // Tree traversal order of the NodePass
232   Order traversalOrder_;
233 };
234 }  // namespace dataset
235 }  // namespace mindspore
236 
237 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_H_
238