• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_
18 #define DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_
19 
20 #include <memory>
21 #include <vector>
22 #include "minddata/dataset/engine/opt/pass.h"
23 
24 namespace mindspore {
25 namespace dataset {
26 
27 class DatasetOp;
28 
29 /// \class EpochInjectionPass epoch_ctrl_pass.h
30 /// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api
31 ///     parsing.
32 class EpochCtrlPass : public IRTreePass {
33   /// \class InjectionFinder
34   /// \brief This is a nested node pass class who's job is to parse the tree and perform any identification logic for
35   ///     operators that need to be injected.  It is run first by the main injection pass to find out what operators
36   ///     it may need to inject.
37   class InjectionFinder : public IRNodePass {
38    public:
39     /// \brief Constructor
40     explicit InjectionFinder(std::shared_ptr<DatasetNode> node);
41 
42     /// \brief Destructor
43     ~InjectionFinder() = default;
44 
45     /// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection.
46     /// \param[in] node The node being visited
47     /// \param[in, out] modified Indicator if the node was changed at all
48     /// \return Status The status code returned
49     Status Visit(std::shared_ptr<RootNode> node, bool *const modified) override;
50 
51     /// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection.
52     /// \param[in] node The node being visited
53     /// \param[in, out] modified Indicator if the node was changed at all
54     /// \return Status The status code returned
55     Status Visit(std::shared_ptr<BuildVocabNode> node, bool *const modified) override;
56 
57 #ifndef ENABLE_ANDROID
58     /// \brief Performs finder work for BuildSentenceVocabNode that has special rules about epoch control injection.
59     /// \param[in] node The node being visited
60     /// \param[in, out] modified Indicator if the node was changed at all
61     /// \return Status The status code returned
62     Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *const modified) override;
63 #endif
64 
65     /// \brief Register the TransferNode for further action.
66     /// \param[in] node The node being visited
67     /// \param[in, out] modified Indicator if the node was changed at all
68     /// \return Status The status code returned
69     Status VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) override;
70 
71     /// \brief Getter
injection_point()72     std::shared_ptr<DatasetNode> injection_point() { return injection_point_; }
73 
74     /// \brief Getter
num_epochs()75     int32_t num_epochs() { return num_epochs_; }
76 
77    private:
78     std::shared_ptr<DatasetNode> injection_point_;
79     int32_t num_epochs_;
80   };
81 
82  public:
83   /// \brief Constructor
84   EpochCtrlPass();
85 
86   /// \brief Destructor
87   ~EpochCtrlPass() = default;
88 
89   /// \brief Runs an injection pass to inject in operators needed at the pre pass stage
90   /// \param[in, out] tree The tree to operate on.
91   /// \param[in, out] Indicate of the tree was modified.
92   /// \return Status The status code returned
93   Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) override;
94 };
95 }  // namespace dataset
96 }  // namespace mindspore
97 
98 #endif  // DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_
99