1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <algorithm>
18 #include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h"
19 #include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h"
20 #include "minddata/dataset/engine/ir/datasetops/root_node.h"
21 #include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
22
23 namespace mindspore {
24 namespace dataset {
25
26 // constructor
InjectionFinder(std::shared_ptr<DatasetNode> node)27 EpochCtrlPass::InjectionFinder::InjectionFinder(std::shared_ptr<DatasetNode> node)
28 : injection_point_(nullptr), num_epochs_(-1) {}
29
30 // Performs finder work for BuildVocabOp that has special rules about epoch control injection
Visit(std::shared_ptr<RootNode> node,bool * const modified)31 Status EpochCtrlPass::InjectionFinder::Visit(std::shared_ptr<RootNode> node, bool *const modified) {
32 RETURN_UNEXPECTED_IF_NULL(node);
33 RETURN_UNEXPECTED_IF_NULL(modified);
34 CHECK_FAIL_RETURN_UNEXPECTED(node->Children().size() > 0,
35 "Invalid data, the node of child should greater than zero.");
36 // The injection is at the child of the root node
37 injection_point_ = node->Children()[0];
38 num_epochs_ = node->num_epochs();
39 return Status::OK();
40 }
41
42 // Performs finder work for BuildVocabOp that has special rules about epoch control injection
Visit(std::shared_ptr<BuildVocabNode> node,bool * const modified)43 Status EpochCtrlPass::InjectionFinder::Visit(std::shared_ptr<BuildVocabNode> node, bool *const modified) {
44 RETURN_UNEXPECTED_IF_NULL(node);
45 RETURN_UNEXPECTED_IF_NULL(modified);
46 injection_point_ = nullptr;
47 return Status::OK();
48 }
49
50 #ifndef ENABLE_ANDROID
51 // Performs finder work for BuildSentencePieceVocabNode that has special rules about epoch control injection
Visit(std::shared_ptr<BuildSentenceVocabNode> node,bool * const modified)52 Status EpochCtrlPass::InjectionFinder::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *const modified) {
53 RETURN_UNEXPECTED_IF_NULL(node);
54 RETURN_UNEXPECTED_IF_NULL(modified);
55 injection_point_ = nullptr;
56 return Status::OK();
57 }
58 #endif
59
VisitAfter(std::shared_ptr<TransferNode> node,bool * const modified)60 Status EpochCtrlPass::InjectionFinder::VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) {
61 RETURN_UNEXPECTED_IF_NULL(node);
62 RETURN_UNEXPECTED_IF_NULL(modified);
63 CHECK_FAIL_RETURN_UNEXPECTED(node->Children().size() > 0,
64 "Invalid data, the node of child should greater than zero.");
65 // Assumption: There is only one TransferNode in a pipeline. This assumption is not validated here.
66 // Move the injection point to the child of this node.
67 injection_point_ = node->Children()[0];
68 return Status::OK();
69 }
70
71 // constructor
EpochCtrlPass()72 EpochCtrlPass::EpochCtrlPass() {}
73
74 // Runs an injection pass to inject in operators needed at the pre pass stage
RunOnTree(std::shared_ptr<DatasetNode> root_ir,bool * const modified)75 Status EpochCtrlPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) {
76 RETURN_UNEXPECTED_IF_NULL(root_ir);
77 RETURN_UNEXPECTED_IF_NULL(modified);
78 MS_LOG(INFO) << "Pre pass: Injection pass started.";
79
80 // First, run the finder to perform any injection info before we can go ahead to drive the op injection work.
81 // The finder can make updates to the EpochInjectionPass object.
82 EpochCtrlPass::InjectionFinder finder(root_ir);
83 RETURN_IF_NOT_OK(finder.Run(root_ir, modified));
84
85 // The first injection logic is to check if we should inject the epoch control op as the root node.
86 // Do not inject the op if the number of epochs is 1.
87 std::shared_ptr<DatasetNode> node = finder.injection_point();
88 int32_t num_epochs = finder.num_epochs();
89 if (num_epochs != 1 && node != nullptr) {
90 auto epoch_ctrl_node = std::make_shared<EpochCtrlNode>(num_epochs);
91 RETURN_IF_NOT_OK(node->InsertAbove(epoch_ctrl_node));
92 }
93 MS_LOG(INFO) << "Pre pass: Injection pass complete.";
94 return Status::OK();
95 }
96 } // namespace dataset
97 } // namespace mindspore
98