1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "tools/optimizer/common/pass_manager_extends.h"
#ifndef _MSC_VER
#include <sys/time.h>
#endif
#include <algorithm>
#include <chrono>
#include <deque>
#include <string>
#include "ir/anf.h"
#include "backend/common/optimizer/cache_manager.h"
25
26 namespace mindspore {
27 namespace opt {
// Cap on re-running the whole pass list in LitePassManager::Run; that loop
// breaks once its re-run counter is strictly greater than this value.
constexpr size_t kMaxRepassTimes{12};
// Number of microseconds in one second; used to express pass run time in "us".
constexpr uint64_t kUSecondInSecond{1000000};
30
PassManager(const std::string & name,bool run_only_once)31 PassManager::PassManager(const std::string &name, bool run_only_once)
32 : name_(name), passes_{}, run_only_once_(run_only_once), cache_manager_(std::make_shared<CacheManager>()) {}
33
AddPass(const PassPtr & pass)34 void PassManager::AddPass(const PassPtr &pass) {
35 if (pass != nullptr) {
36 passes_.push_back(pass);
37 }
38 }
39
RunPass(const FuncGraphPtr & func_graph,size_t pass_id,const PassPtr & pass) const40 bool PassManager::RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const PassPtr &pass) const {
41 MS_LOG(ERROR) << "stub func";
42 return false;
43 }
44
GetPassFullname(size_t pass_id,const PassPtr & pass) const45 std::string PassManager::GetPassFullname(size_t pass_id, const PassPtr &pass) const {
46 return std::string("hwopt_") + name() + "_" + std::to_string(pass_id) + "_" + pass->name();
47 }
48
DumpPassIR(const FuncGraphPtr & func_graph,const std::string & pass_fullname) const49 void PassManager::DumpPassIR(const FuncGraphPtr &func_graph, const std::string &pass_fullname) const {
50 MS_LOG(ERROR) << "stub func";
51 }
52
Run(const FuncGraphPtr & func_graph,const std::vector<PassPtr> & passes) const53 bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr> &passes) const {
54 MS_LOG(ERROR) << "stub func";
55 return false;
56 }
57
Run(const FuncGraphPtr & func_graph) const58 bool PassManager::Run(const FuncGraphPtr &func_graph) const {
59 MS_LOG(ERROR) << "stub func";
60 return false;
61 }
62
AddPass(const PassPtr & pass)63 void LitePassManager::AddPass(const PassPtr &pass) {
64 AddPass(pass, {});
65 }
66
AddPass(const PassPtr & pass,const std::set<std::string> & pass_blacklists)67 void LitePassManager::AddPass(const PassPtr &pass, const std::set<std::string>& pass_blacklists) {
68 if (pass == nullptr) {
69 MS_LOG(ERROR) << "pass is nullptr";
70 return;
71 }
72 if (pass_blacklists.find(pass->name()) != pass_blacklists.end()) {
73 MS_LOG(INFO) << "disable pass: " << pass->name();
74 return;
75 }
76 passes_.push_back(pass);
77 }
78
RunPass(const FuncGraphPtr & func_graph,size_t pass_id,const PassPtr & pass) const79 bool LitePassManager::RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const PassPtr &pass) const {
80 bool changed = false;
81 #if defined(_WIN32) || defined(_WIN64)
82 auto start_time = std::chrono::steady_clock::now();
83 #else
84 struct timeval start_time {};
85 struct timeval end_time {};
86 (void)gettimeofday(&start_time, nullptr);
87 #endif
88 if (pass->Run(func_graph)) {
89 MS_LOG(DEBUG) << "Run pass and find change";
90 changed = true;
91 }
92 #if defined(_WIN32) || defined(_WIN64)
93 auto end_time = std::chrono::steady_clock::now();
94 std::chrono::duration<double, std::ratio<1, kUSecondInSecond>> cost = end_time - start_time;
95 MS_LOG(INFO) << "Run pass " << GetPassFullname(pass_id, pass) << " in " << cost.count() << " us.";
96 #else
97 (void)gettimeofday(&end_time, nullptr);
98 uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
99 cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
100 MS_LOG(INFO) << "Run pass " << GetPassFullname(pass_id, pass) << " in " << cost << " us.";
101 #endif
102 return changed;
103 }
104
GetPassFullname(size_t pass_id,const PassPtr & pass) const105 std::string LitePassManager::GetPassFullname(size_t pass_id, const PassPtr &pass) const {
106 return "hwopt_" + name() + "_" + std::to_string(pass_id) + "_" + pass->name();
107 }
108
Run(const FuncGraphPtr & func_graph,const std::vector<PassPtr> & passes) const109 bool LitePassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr> &passes) const {
110 if (func_graph == nullptr) {
111 return false;
112 }
113 bool changed = false;
114 size_t num = 0;
115 for (const auto &pass : passes) {
116 if (pass != nullptr) {
117 changed = RunPass(func_graph, num, pass) || changed;
118 } else {
119 MS_LOG(INFO) << "pass is null";
120 }
121 num++;
122 }
123 return changed;
124 }
125
Run(const FuncGraphPtr & func_graph) const126 bool LitePassManager::Run(const FuncGraphPtr &func_graph) const {
127 if (func_graph == nullptr) {
128 return false;
129 }
130 bool changed = false;
131 size_t count = 0;
132 // run all passes
133 bool change = true;
134 while (change) {
135 change = Run(func_graph, passes_);
136 changed = change || changed;
137 if (run_only_once_ || count > kMaxRepassTimes) {
138 break;
139 }
140 count++;
141 MS_LOG(INFO) << "Run pass counts:" << count;
142 }
143 return changed;
144 }
145 } // namespace opt
146 } // namespace mindspore
147