1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_H_ 19 20 #include <memory> 21 #include <queue> 22 23 #include "minddata/dataset/engine/execution_tree.h" 24 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" 25 #include "minddata/dataset/util/status.h" 26 27 namespace mindspore { 28 namespace dataset { 29 // Non-leaf IR node 30 class BatchNode; 31 class BucketBatchByLengthNode; 32 class BuildVocabNode; 33 #ifndef ENABLE_ANDROID 34 class CacheLookupNode; 35 class CacheMergeNode; 36 class CacheNode; 37 #endif 38 class ConcatNode; 39 class EpochCtrlNode; 40 class FilterNode; 41 class MapNode; 42 class ProjectNode; 43 class RenameNode; 44 class RepeatNode; 45 class RootNode; 46 class ShuffleNode; 47 class SkipNode; 48 class TakeNode; 49 class TFRecordNode; 50 class TransferNode; 51 class ZipNode; 52 #ifdef ENABLE_PYTHON 53 class SyncWaitNode; 54 #endif 55 #ifndef ENABLE_ANDROID 56 class BuildSentenceVocabNode; 57 #endif 58 // Leaf IR node 59 class AlbumNode; 60 class CelebANode; 61 class Cifar100Node; 62 class Cifar10Node; 63 class CocoNode; 64 class ImageFolderNode; 65 class ManifestNode; 66 class MnistNode; 67 class RandomNode; 68 class VOCNode; 69 #ifdef ENABLE_PYTHON 70 class GeneratorNode; 71 #endif 72 #ifndef ENABLE_ANDROID 73 class CLUENode; 74 class CSVNode; 75 class MindDataNode; 76 class TextFileNode; 77 class TFRecordNode; 78 #endif 79 80 // The base class Pass is the basic unit of tree transformation. 81 // The actual implementation of the passes will be derived from here. 82 class IRPass : public std::enable_shared_from_this<IRPass> { 83 public: 84 // Run the transformation pass against the IR tree. 85 // @param root_ir - Pointer to the IR tree to be transformed. 86 // @param modified - Pointer to the modified flag, 87 virtual Status Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) = 0; 88 89 virtual ~IRPass() = default; 90 }; 91 92 // IRTreePass is a basic Pass class which performs transformation on IR tree directly. 93 class IRTreePass : public IRPass { 94 public: 95 /// \brief Run the transformation pass against the IR tree. 96 /// \param[in,out] root_ir Pointer to the IR tree to be transformed. 97 /// \param[in,out] modified Indicate if the tree was modified 98 Status Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) final; 99 100 /// \brief Derived classes may implement the runOnTree function to implement tree transformation. 101 /// "modified" flag needs to be set to true if tree is modified during the pass execution. 102 /// \param[in,out] tree The tree to operate on. 103 /// \param[in,out] Indicate if the tree was modified. 104 /// \return Status The status code returned RunOnTree(std::shared_ptr<DatasetNode> root_ir,bool * const modified)105 virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) { return Status::OK(); } 106 }; 107 108 // IRNodePass is a base Pass class which performs transformation on node visiting. 109 // IRNodePass implements Visitor design pattern. 110 // The visiting happens twice for each node in the DFS traversal, one on the way down of the traversal, 111 // and the other when all the descending nodes are visited. 112 // Actual transformation is done by implementing a new derived class of IRNodePass. 113 // The derived class will implement the method Visit()/VisitAfter() passing specified node types 114 // it wants to action on them, overriding the ones defined in IRNodePass. 115 // If the derived class wants to perform the same action on all node types, 116 // it can simply implement the method Visit()/VisitAfter() passing the base class DatasetNode. 117 // This is made possible by overloading the method Visit()/VisitAfter() on each node type to fall back 118 // to call the Visit()/VisitAfter() in this parent IRNodePass class. 119 class IRNodePass : public IRPass { 120 public: 121 // Tree traversal order 122 enum Order { DFS, BFS }; 123 124 // Constructor 125 // Default DFS traversal 126 explicit IRNodePass(Order order = Order::DFS) { traversalOrder_ = order; } 127 128 ~IRNodePass() = default; 129 130 /// \brief Run the transformation pass against the IR tree 131 /// \param[in,out] root_ir Pointer to the IR tree to be transformed 132 /// \param[in,out] modified Indicator if the tree was changed 133 Status Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) final; 134 135 /// \brief Derived classes may implement the Visit function to implement any initial visit work on the way down 136 /// a tree traversal. "modified" flag needs to be set to true if node is modified during the pass execution 137 /// \param[in] node The node being visited 138 /// \param[out] modified Indicator if the node was changed at all 139 /// \return Status The status code returned Visit(std::shared_ptr<DatasetNode> node,bool * const modified)140 virtual Status Visit(std::shared_ptr<DatasetNode> node, bool *const modified) { return Status::OK(); } 141 142 /// \brief Derived classes may implement the VisitAfter function to implement node level tree transformation 143 /// "modified" flag needs to be set to true if node is modified during the pass execution 144 /// \param[in] node The node being visited 145 /// \param[out] modified Indicator if the node was changed at all. 146 /// \return Status The status code returned VisitAfter(std::shared_ptr<DatasetNode> node,bool * const modified)147 virtual Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) { return Status::OK(); } 148 149 // Visit()/VisitAfter() method to be overridden. 150 // These pairs of Visit()/VisitAfter() for each derived class of DatasetNode are defined here. 151 // Their implementation are in .cc file to avoid adding the include files of those derived classes. 152 // The implementation simply falls back to call Visit()/VisitAfter of class DatasetNode, the parent of 153 // the derived classes. With this technique, the transformation classes derived from NodePass needs only to 154 // implement Visit()/VisitAfter() passing DatasetNode if it wants to action on any derived classes 155 // of DatasetNode in the same way. 156 // Note that virtual template functions are not permitted in C++. 157 // 158 // Non-leaf IR node 159 virtual Status Visit(std::shared_ptr<BatchNode> node, bool *const modified); 160 virtual Status VisitAfter(std::shared_ptr<BatchNode> node, bool *const modified); 161 virtual Status Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *const modified); 162 virtual Status VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *const modified); 163 #ifndef ENABLE_ANDROID 164 virtual Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *const modified); 165 virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *const modified); 166 #endif 167 virtual Status Visit(std::shared_ptr<BuildVocabNode> node, bool *const modified); 168 virtual Status VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *const modified); 169 virtual Status Visit(std::shared_ptr<ConcatNode> node, bool *const modified); 170 virtual Status VisitAfter(std::shared_ptr<ConcatNode> node, bool *const modified); 171 #ifndef ENABLE_ANDROID 172 virtual Status Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified); 173 virtual Status VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified); 174 virtual Status Visit(std::shared_ptr<CacheLookupNode> node, bool *const modified); 175 virtual Status VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified); 176 virtual Status Visit(std::shared_ptr<CacheNode> node, bool *const modified); 177 virtual Status VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified); 178 #endif 179 virtual Status Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified); 180 virtual Status VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified); 181 virtual Status Visit(std::shared_ptr<FilterNode> node, bool *const modified); 182 virtual Status VisitAfter(std::shared_ptr<FilterNode> node, bool *const modified); 183 #ifdef ENABLE_PYTHON 184 virtual Status Visit(std::shared_ptr<GeneratorNode> node, bool *const modified); 185 virtual Status VisitAfter(std::shared_ptr<GeneratorNode> node, bool *const modified); 186 #endif 187 virtual Status Visit(std::shared_ptr<MapNode> node, bool *const modified); 188 virtual Status VisitAfter(std::shared_ptr<MapNode> node, bool *const modified); 189 #ifndef ENABLE_ANDROID 190 virtual Status Visit(std::shared_ptr<MindDataNode> node, bool *const modified); 191 virtual Status VisitAfter(std::shared_ptr<MindDataNode> node, bool *const modified); 192 #endif 193 virtual Status Visit(std::shared_ptr<ProjectNode> node, bool *const modified); 194 virtual Status VisitAfter(std::shared_ptr<ProjectNode> node, bool *const modified); 195 virtual Status Visit(std::shared_ptr<RandomNode> node, bool *const modified); 196 virtual Status VisitAfter(std::shared_ptr<RandomNode> node, bool *const modified); 197 virtual Status Visit(std::shared_ptr<RenameNode> node, bool *const modified); 198 virtual Status VisitAfter(std::shared_ptr<RenameNode> node, bool *const modified); 199 virtual Status Visit(std::shared_ptr<RepeatNode> node, bool *const modified); 200 virtual Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified); 201 virtual Status Visit(std::shared_ptr<RootNode> node, bool *const modified); 202 virtual Status VisitAfter(std::shared_ptr<RootNode> node, bool *const modified); 203 virtual Status Visit(std::shared_ptr<ShuffleNode> node, bool *const modified); 204 virtual Status VisitAfter(std::shared_ptr<ShuffleNode> node, bool *const modified); 205 virtual Status Visit(std::shared_ptr<SkipNode> node, bool *const modified); 206 virtual Status VisitAfter(std::shared_ptr<SkipNode> node, bool *const modified); 207 #ifdef ENABLE_PYTHON 208 virtual Status Visit(std::shared_ptr<SyncWaitNode> node, bool *const modified); 209 virtual Status VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *const modified); 210 #endif 211 virtual Status Visit(std::shared_ptr<TakeNode> node, bool *const modified); 212 virtual Status VisitAfter(std::shared_ptr<TakeNode> node, bool *const modified); 213 virtual Status Visit(std::shared_ptr<TFRecordNode> node, bool *const modified); 214 virtual Status VisitAfter(std::shared_ptr<TFRecordNode> node, bool *const modified); 215 virtual Status Visit(std::shared_ptr<TransferNode> node, bool *const modified); 216 virtual Status VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified); 217 virtual Status Visit(std::shared_ptr<ZipNode> node, bool *const modified); 218 virtual Status VisitAfter(std::shared_ptr<ZipNode> node, bool *const modified); 219 220 // leaf-IR Node 221 virtual Status Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified); 222 virtual Status Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified); 223 224 private: 225 // Helper function to perform DFS visit 226 Status DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *const modified); 227 228 // Helper function to perform BFS visit 229 Status BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *const modified); 230 231 // Tree traversal order of the NodePass 232 Order traversalOrder_; 233 }; 234 } // namespace dataset 235 } // namespace mindspore 236 237 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_H_ 238