1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_POST_GENERATOR_NODE_PASS_H 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_POST_GENERATOR_NODE_PASS_H 19 20 #include <memory> 21 #include <utility> 22 #include <vector> 23 #include "minddata/dataset/engine/opt/pass.h" 24 25 namespace mindspore { 26 namespace dataset { 27 28 /// \class GeneratorNodePass repeat_pass.h 29 /// \brief This is a NodePass who's job is to perform setup actions for RepeatOps. A RepeatOp needs to have references 30 /// to the eoe-producing (typically leaf) nodes underneath it. 31 class GeneratorNodePass : public IRNodePass { 32 public: 33 /// \brief Constructor 34 GeneratorNodePass(); 35 36 /// \brief Destructor 37 ~GeneratorNodePass() = default; 38 39 /// \brief Record the starting point to collect the Generator node 40 /// \param[in] node The node being visited 41 /// \param[in, out] modified Indicator if the node was changed at all 42 /// \return Status The status code returned 43 Status Visit(std::shared_ptr<RepeatNode> node, bool *const modified) override; 44 45 /// \brief Record the starting point to collect the Generator node 46 /// \param[in] node The node being visited 47 /// \param[in, out] modified Indicator if the node was changed at all 48 /// \return Status The status code returned 49 Status Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) override; 50 51 /// \brief Add the Generator node to the set 52 /// \param[in] node The node being visited 53 /// \param[in, out] modified Indicator if the node was changed at all 54 /// \return Status The status code returned 55 Status Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) override; 56 57 /// \brief Add the Generator node(s) from the set to this Repeat node for run-time processing 58 /// \param[in] node The node being visited 59 /// \param[in, out] modified Indicator if the node was changed at all 60 /// \return Status The status code returned 61 Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) override; 62 63 /// \brief Add the Generator node(s) from the set to this EpochCtrl node for run-time processing 64 /// \param[in] node The node being visited 65 /// \param[in, out] modified Indicator if the node was changed at all 66 /// \return Status The status code returned 67 Status VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) override; 68 69 private: 70 std::vector<std::shared_ptr<RepeatNode>> repeat_ancestors_; 71 }; 72 } // namespace dataset 73 } // namespace mindspore 74 75 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_POST_GENERATOR_NODE_PASS_H 76