1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ 19 20 #include <memory> 21 #include <stack> 22 #include <utility> 23 #include "minddata/dataset/engine/opt/pass.h" 24 25 namespace mindspore { 26 namespace dataset { 27 28 /// \class RepeatPass 29 /// \brief This is a post pass that calculate the number of repeats the pipeline needs to fetch the data. 30 class RepeatPass : public IRNodePass { 31 public: 32 using op_stack = std::stack<std::shared_ptr<DatasetNode>>; 33 34 /// \brief Constructor 35 RepeatPass(); 36 37 /// \brief Destructor 38 ~RepeatPass() = default; 39 40 /// \brief Identifies the subtree below this node as being in a repeated path of the tree. 41 /// \param[in] node The node being visited 42 /// \param[in,out] modified Indicator if the node was changed at all 43 /// \return Status The status code returned 44 Status Visit(std::shared_ptr<RepeatNode> node, bool *const modified) override; 45 46 /// \brief Identifies the subtree below this node as being in a repeated path of the tree. 47 /// \param[in] node The node being visited 48 /// \param[in,out] modified Indicator if the node was changed at all 49 /// \return Status The status code returned 50 Status Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) override; 51 52 #ifndef ENABLE_ANDROID 53 /// \brief Identifies the subtree below this node as being in a cache merge path 54 /// \param[in] node The node being visited 55 /// \param[in,out] modified Indicator if the node was changed at all 56 /// \return Status The status code returned 57 Status Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified) override; 58 59 /// \brief Identifies the subtree below this node as being cached 60 /// \param[in] node The node being visited 61 /// \param[in,out] modified Indicator if the node was changed at all 62 /// \return Status The status code returned 63 Status Visit(std::shared_ptr<CacheNode> node, bool *const modified) override; 64 #endif 65 66 /// \brief Hooks up any identified eoe nodes under this repeat. 67 /// \param[in] node The node being visited 68 /// \param[in,out] modified Indicator if the node was changed at all 69 /// \return Status The status code returned 70 Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) override; 71 72 /// \brief Hooks up any identified eoe nodes under this repeat. 73 /// \param[in] node The node being visited 74 /// \param[in,out] modified Indicator if the node was changed at all 75 /// \return Status The status code returned 76 Status VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) override; 77 78 #ifndef ENABLE_ANDROID 79 /// \brief CacheNode removes previous leaf ops and replaces them with itself 80 /// \param[in] node The node being visited 81 /// \param[in,out] modified Indicator if the node was changed at all 82 /// \return Status The status code returned 83 Status VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified) override; 84 85 /// \brief Turns off the tracking for operations under merge op 86 /// \param[in] node The node being visited 87 /// \param[in,out] modified Indicator if the node was changed at all 88 /// \return Status The status code returned 89 Status VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified) override; 90 91 /// \brief Saves the lookup up in case it needs to be referenced by a repeat 92 /// \param[in] node The node being visited 93 /// \param[in,out] modified Indicator if the node was changed at all 94 /// \return Status The status code returned 95 Status VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified) override; 96 #endif 97 98 /// \brief Sets the epoch count for TransferNode 99 /// \param[in] node The node being visited 100 /// \param[in,out] modified Indicator if the node was changed at all 101 /// \return Status The status code returned 102 Status VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) override; 103 104 /// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up 105 /// for use with a controlling repeat above it. 106 /// \param[in] node The node being visited 107 /// \param[in,out] modified Indicator if the node was changed at all 108 /// \return Status The status code returned 109 Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) override; 110 111 private: 112 /// \brief Adds an operator to the cached stack save area 113 /// \param node - The dataset node to add to cached stack 114 /// \return Status The status code returned 115 void AddToCachedNodeStack(const std::shared_ptr<DatasetNode> &node); 116 117 /// \brief Pops an operator from the cached stack save area 118 /// \return shared_ptr to the popped dataset node 119 std::shared_ptr<DatasetNode> PopFromCachedNodeStack(); 120 121 bool is_merge_; // T/F if we are processing under a cache merge node 122 bool is_cached_; // T/F is we are processing under a cache node 123 int32_t num_repeats_; // A multiplier to the total number of repeats 124 int32_t num_epochs_; // To save the total number of epochs 125 op_stack cached_node_stacks_; // A save area for operators under a cache node 126 std::shared_ptr<DatasetNode> cache_lookup_; // A save area for a cache lookup node 127 }; 128 } // namespace dataset 129 } // namespace mindspore 130 131 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ 132