1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ 18 #define DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ 19 20 #include <memory> 21 #include <vector> 22 #include "minddata/dataset/engine/opt/pass.h" 23 24 namespace mindspore { 25 namespace dataset { 26 27 class DatasetOp; 28 29 /// \class EpochInjectionPass epoch_ctrl_pass.h 30 /// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api 31 /// parsing. 32 class EpochCtrlPass : public IRTreePass { 33 /// \class InjectionFinder 34 /// \brief This is a nested node pass class who's job is to parse the tree and perform any identification logic for 35 /// operators that need to be injected. It is run first by the main injection pass to find out what operators 36 /// it may need to inject. 37 class InjectionFinder : public IRNodePass { 38 public: 39 /// \brief Constructor 40 explicit InjectionFinder(std::shared_ptr<DatasetNode> node); 41 42 /// \brief Destructor 43 ~InjectionFinder() = default; 44 45 /// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection. 46 /// \param[in] node The node being visited 47 /// \param[in, out] modified Indicator if the node was changed at all 48 /// \return Status The status code returned 49 Status Visit(std::shared_ptr<RootNode> node, bool *const modified) override; 50 51 /// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection. 52 /// \param[in] node The node being visited 53 /// \param[in, out] modified Indicator if the node was changed at all 54 /// \return Status The status code returned 55 Status Visit(std::shared_ptr<BuildVocabNode> node, bool *const modified) override; 56 57 #ifndef ENABLE_ANDROID 58 /// \brief Performs finder work for BuildSentenceVocabNode that has special rules about epoch control injection. 59 /// \param[in] node The node being visited 60 /// \param[in, out] modified Indicator if the node was changed at all 61 /// \return Status The status code returned 62 Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *const modified) override; 63 #endif 64 65 /// \brief Register the TransferNode for further action. 66 /// \param[in] node The node being visited 67 /// \param[in, out] modified Indicator if the node was changed at all 68 /// \return Status The status code returned 69 Status VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) override; 70 71 /// \brief Getter injection_point()72 std::shared_ptr<DatasetNode> injection_point() { return injection_point_; } 73 74 /// \brief Getter num_epochs()75 int32_t num_epochs() { return num_epochs_; } 76 77 private: 78 std::shared_ptr<DatasetNode> injection_point_; 79 int32_t num_epochs_; 80 }; 81 82 public: 83 /// \brief Constructor 84 EpochCtrlPass(); 85 86 /// \brief Destructor 87 ~EpochCtrlPass() = default; 88 89 /// \brief Runs an injection pass to inject in operators needed at the pre pass stage 90 /// \param[in, out] tree The tree to operate on. 91 /// \param[in, out] Indicate of the tree was modified. 92 /// \return Status The status code returned 93 Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) override; 94 }; 95 } // namespace dataset 96 } // namespace mindspore 97 98 #endif // DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ 99