1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_REMOVAL_PASS_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_REMOVAL_PASS_H_ 19 20 #include <memory> 21 #include <vector> 22 #include "minddata/dataset/engine/opt/pass.h" 23 24 namespace mindspore { 25 namespace dataset { 26 27 class DatasetOp; 28 29 /// \class RemovalPass removal_pass.h 30 /// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which 31 /// nodes should be removed, and then removes them. 32 class NodeRemovalPass : public IRTreePass { 33 /// \class RemovalNodes 34 /// \brief This is a NodePass whose job is to identify which nodes should be removed. 35 /// It works in conjunction with the removal_pass. 36 class RemovalNodes : public IRNodePass { 37 public: 38 /// \brief Constructor 39 /// \param[in] removal_pass Raw pointer back to controlling tree pass 40 RemovalNodes(); 41 42 /// \brief Destructor 43 ~RemovalNodes() = default; 44 45 /// \brief Perform RepeatNode removal check 46 /// \param[in] node The node being visited 47 /// \param[in, out] modified Indicator if the node was changed at all 48 /// \return Status The status code returned 49 Status Visit(std::shared_ptr<RepeatNode> node, bool *const modified) override; 50 51 /// \brief Perform SkipNode removal check 52 /// \param[in] node The node being visited 53 /// \param[in, out] modified Indicator if the node was changed at all 54 /// \return Status The status code returned 55 Status Visit(std::shared_ptr<SkipNode> node, bool *const modified) override; 56 57 /// \brief Perform TakeNode removal check 58 /// \param[in] node The node being visited 59 /// \param[in, out] modified Indicator if the node was changed at all 60 /// \return Status The status code returned 61 Status Visit(std::shared_ptr<TakeNode> node, bool *const modified) override; 62 63 /// \brief Getter 64 /// \return All the nodes to be removed nodes_to_remove()65 std::vector<std::shared_ptr<DatasetNode>> nodes_to_remove() { return nodes_to_remove_; } 66 67 private: 68 std::vector<std::shared_ptr<DatasetNode>> nodes_to_remove_; 69 }; 70 71 public: 72 /// \brief Constructor 73 NodeRemovalPass(); 74 75 /// \brief Destructor 76 ~NodeRemovalPass() = default; 77 78 /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them. 79 /// \param[in, out] tree The tree to operate on. 80 /// \param[in, out] Indicate of the tree was modified. 81 /// \return Status The status code returned 82 Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) override; 83 }; 84 } // namespace dataset 85 } // namespace mindspore 86 87 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_REMOVAL_PASS_H_ 88