• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_
19 
20 #include <memory>
21 #include <stack>
22 #include <utility>
23 #include "minddata/dataset/engine/opt/pass.h"
24 
25 namespace mindspore {
26 namespace dataset {
27 
28 /// \class RepeatPass
29 /// \brief This is a post pass that calculate the number of repeats the pipeline needs to fetch the data.
30 class RepeatPass : public IRNodePass {
31  public:
32   using op_stack = std::stack<std::shared_ptr<DatasetNode>>;
33 
34   /// \brief Constructor
35   RepeatPass();
36 
37   /// \brief Destructor
38   ~RepeatPass() = default;
39 
40   /// \brief Identifies the subtree below this node as being in a repeated path of the tree.
41   /// \param[in] node The node being visited
42   /// \param[in,out] modified Indicator if the node was changed at all
43   /// \return Status The status code returned
44   Status Visit(std::shared_ptr<RepeatNode> node, bool *const modified) override;
45 
46   /// \brief Identifies the subtree below this node as being in a repeated path of the tree.
47   /// \param[in] node The node being visited
48   /// \param[in,out] modified Indicator if the node was changed at all
49   /// \return Status The status code returned
50   Status Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) override;
51 
52 #ifndef ENABLE_ANDROID
53   /// \brief Identifies the subtree below this node as being in a cache merge path
54   /// \param[in] node The node being visited
55   /// \param[in,out] modified Indicator if the node was changed at all
56   /// \return Status The status code returned
57   Status Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified) override;
58 
59   /// \brief Identifies the subtree below this node as being cached
60   /// \param[in] node The node being visited
61   /// \param[in,out] modified Indicator if the node was changed at all
62   /// \return Status The status code returned
63   Status Visit(std::shared_ptr<CacheNode> node, bool *const modified) override;
64 #endif
65 
66   /// \brief Hooks up any identified eoe nodes under this repeat.
67   /// \param[in] node The node being visited
68   /// \param[in,out] modified Indicator if the node was changed at all
69   /// \return Status The status code returned
70   Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) override;
71 
72   /// \brief Hooks up any identified eoe nodes under this repeat.
73   /// \param[in] node The node being visited
74   /// \param[in,out] modified Indicator if the node was changed at all
75   /// \return Status The status code returned
76   Status VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) override;
77 
78 #ifndef ENABLE_ANDROID
79   /// \brief CacheNode removes previous leaf ops and replaces them with itself
80   /// \param[in] node The node being visited
81   /// \param[in,out] modified Indicator if the node was changed at all
82   /// \return Status The status code returned
83   Status VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified) override;
84 
85   /// \brief Turns off the tracking for operations under merge op
86   /// \param[in] node The node being visited
87   /// \param[in,out] modified Indicator if the node was changed at all
88   /// \return Status The status code returned
89   Status VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified) override;
90 
91   /// \brief Saves the lookup up in case it needs to be referenced by a repeat
92   /// \param[in] node The node being visited
93   /// \param[in,out] modified Indicator if the node was changed at all
94   /// \return Status The status code returned
95   Status VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified) override;
96 #endif
97 
98   /// \brief Sets the epoch count for TransferNode
99   /// \param[in] node The node being visited
100   /// \param[in,out] modified Indicator if the node was changed at all
101   /// \return Status The status code returned
102   Status VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) override;
103 
104   /// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
105   ///     for use with a controlling repeat above it.
106   /// \param[in] node The node being visited
107   /// \param[in,out] modified Indicator if the node was changed at all
108   /// \return Status The status code returned
109   Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) override;
110 
111  private:
112   /// \brief Adds an operator to the cached stack save area
113   /// \param node - The dataset node to add to cached stack
114   /// \return Status The status code returned
115   void AddToCachedNodeStack(const std::shared_ptr<DatasetNode> &node);
116 
117   /// \brief Pops an operator from the cached stack save area
118   /// \return shared_ptr to the popped dataset node
119   std::shared_ptr<DatasetNode> PopFromCachedNodeStack();
120 
121   bool is_merge_;                              // T/F if we are processing under a cache merge node
122   bool is_cached_;                             // T/F is we are processing under a cache node
123   int32_t num_repeats_;                        // A multiplier to the total number of repeats
124   int32_t num_epochs_;                         // To save the total number of epochs
125   op_stack cached_node_stacks_;                // A save area for operators under a cache node
126   std::shared_ptr<DatasetNode> cache_lookup_;  // A save area for a cache lookup node
127 };
128 }  // namespace dataset
129 }  // namespace mindspore
130 
131 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_
132