• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <algorithm>
18 #include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h"
19 #include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h"
20 #include "minddata/dataset/engine/ir/datasetops/root_node.h"
21 #include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
22 
23 namespace mindspore {
24 namespace dataset {
25 
26 // constructor
InjectionFinder(std::shared_ptr<DatasetNode> node)27 EpochCtrlPass::InjectionFinder::InjectionFinder(std::shared_ptr<DatasetNode> node)
28     : injection_point_(nullptr), num_epochs_(-1) {}
29 
30 // Performs finder work for BuildVocabOp that has special rules about epoch control injection
Visit(std::shared_ptr<RootNode> node,bool * const modified)31 Status EpochCtrlPass::InjectionFinder::Visit(std::shared_ptr<RootNode> node, bool *const modified) {
32   RETURN_UNEXPECTED_IF_NULL(node);
33   RETURN_UNEXPECTED_IF_NULL(modified);
34   CHECK_FAIL_RETURN_UNEXPECTED(node->Children().size() > 0,
35                                "Invalid data, the node of child should greater than zero.");
36   // The injection is at the child of the root node
37   injection_point_ = node->Children()[0];
38   num_epochs_ = node->num_epochs();
39   return Status::OK();
40 }
41 
42 // Performs finder work for BuildVocabOp that has special rules about epoch control injection
Visit(std::shared_ptr<BuildVocabNode> node,bool * const modified)43 Status EpochCtrlPass::InjectionFinder::Visit(std::shared_ptr<BuildVocabNode> node, bool *const modified) {
44   RETURN_UNEXPECTED_IF_NULL(node);
45   RETURN_UNEXPECTED_IF_NULL(modified);
46   injection_point_ = nullptr;
47   return Status::OK();
48 }
49 
50 #ifndef ENABLE_ANDROID
51 // Performs finder work for BuildSentencePieceVocabNode that has special rules about epoch control injection
Visit(std::shared_ptr<BuildSentenceVocabNode> node,bool * const modified)52 Status EpochCtrlPass::InjectionFinder::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *const modified) {
53   RETURN_UNEXPECTED_IF_NULL(node);
54   RETURN_UNEXPECTED_IF_NULL(modified);
55   injection_point_ = nullptr;
56   return Status::OK();
57 }
58 #endif
59 
VisitAfter(std::shared_ptr<TransferNode> node,bool * const modified)60 Status EpochCtrlPass::InjectionFinder::VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) {
61   RETURN_UNEXPECTED_IF_NULL(node);
62   RETURN_UNEXPECTED_IF_NULL(modified);
63   CHECK_FAIL_RETURN_UNEXPECTED(node->Children().size() > 0,
64                                "Invalid data, the node of child should greater than zero.");
65   // Assumption: There is only one TransferNode in a pipeline. This assumption is not validated here.
66   // Move the injection point to the child of this node.
67   injection_point_ = node->Children()[0];
68   return Status::OK();
69 }
70 
71 // constructor
EpochCtrlPass()72 EpochCtrlPass::EpochCtrlPass() {}
73 
74 // Runs an injection pass to inject in operators needed at the pre pass stage
RunOnTree(std::shared_ptr<DatasetNode> root_ir,bool * const modified)75 Status EpochCtrlPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) {
76   RETURN_UNEXPECTED_IF_NULL(root_ir);
77   RETURN_UNEXPECTED_IF_NULL(modified);
78   MS_LOG(INFO) << "Pre pass: Injection pass started.";
79 
80   // First, run the finder to perform any injection info before we can go ahead to drive the op injection work.
81   // The finder can make updates to the EpochInjectionPass object.
82   EpochCtrlPass::InjectionFinder finder(root_ir);
83   RETURN_IF_NOT_OK(finder.Run(root_ir, modified));
84 
85   // The first injection logic is to check if we should inject the epoch control op as the root node.
86   // Do not inject the op if the number of epochs is 1.
87   std::shared_ptr<DatasetNode> node = finder.injection_point();
88   int32_t num_epochs = finder.num_epochs();
89   if (num_epochs != 1 && node != nullptr) {
90     auto epoch_ctrl_node = std::make_shared<EpochCtrlNode>(num_epochs);
91     RETURN_IF_NOT_OK(node->InsertAbove(epoch_ctrl_node));
92   }
93   MS_LOG(INFO) << "Pre pass: Injection pass complete.";
94   return Status::OK();
95 }
96 }  // namespace dataset
97 }  // namespace mindspore
98