• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_REMOVAL_PASS_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_REMOVAL_PASS_H_
19 
20 #include <memory>
21 #include <vector>
22 #include "minddata/dataset/engine/opt/pass.h"
23 
24 namespace mindspore {
25 namespace dataset {
26 
27 class DatasetOp;
28 
29 /// \class RemovalPass removal_pass.h
30 /// \brief This is a tree pass that will remove nodes.  It uses removal_nodes to first identify which
31 ///     nodes should be removed, and then removes them.
32 class NodeRemovalPass : public IRTreePass {
33   /// \class RemovalNodes
34   /// \brief This is a NodePass whose job is to identify which nodes should be removed.
35   ///     It works in conjunction with the removal_pass.
36   class RemovalNodes : public IRNodePass {
37    public:
38     /// \brief Constructor
39     /// \param[in] removal_pass Raw pointer back to controlling tree pass
40     RemovalNodes();
41 
42     /// \brief Destructor
43     ~RemovalNodes() = default;
44 
45     /// \brief Perform RepeatNode removal check
46     /// \param[in] node The node being visited
47     /// \param[in, out] modified Indicator if the node was changed at all
48     /// \return Status The status code returned
49     Status Visit(std::shared_ptr<RepeatNode> node, bool *const modified) override;
50 
51     /// \brief Perform SkipNode removal check
52     /// \param[in] node The node being visited
53     /// \param[in, out] modified Indicator if the node was changed at all
54     /// \return Status The status code returned
55     Status Visit(std::shared_ptr<SkipNode> node, bool *const modified) override;
56 
57     /// \brief Perform TakeNode removal check
58     /// \param[in] node The node being visited
59     /// \param[in, out] modified Indicator if the node was changed at all
60     /// \return Status The status code returned
61     Status Visit(std::shared_ptr<TakeNode> node, bool *const modified) override;
62 
63     /// \brief Getter
64     /// \return All the nodes to be removed
nodes_to_remove()65     std::vector<std::shared_ptr<DatasetNode>> nodes_to_remove() { return nodes_to_remove_; }
66 
67    private:
68     std::vector<std::shared_ptr<DatasetNode>> nodes_to_remove_;
69   };
70 
71  public:
72   /// \brief Constructor
73   NodeRemovalPass();
74 
75   /// \brief Destructor
76   ~NodeRemovalPass() = default;
77 
78   /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
79   /// \param[in, out] tree The tree to operate on.
80   /// \param[in, out] Indicate of the tree was modified.
81   /// \return Status The status code returned
82   Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) override;
83 };
84 }  // namespace dataset
85 }  // namespace mindspore
86 
87 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_REMOVAL_PASS_H_
88