• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "tools/optimizer/common/pass_manager_extends.h"
#ifndef _MSC_VER
#include <sys/time.h>
#endif
#include <algorithm>
#include <chrono>
#include <deque>
#include <string>
#include "ir/anf.h"
#include "backend/common/optimizer/cache_manager.h"
25 
26 namespace mindspore {
27 namespace opt {
// Round-count threshold checked by LitePassManager::Run before it starts another round over the pass list.
constexpr size_t kMaxRepassTimes{12};
// Microseconds per second; used to express per-pass run time in microseconds.
constexpr uint64_t kUSecondInSecond{1'000'000};
30 
PassManager(const std::string & name,bool run_only_once)31 PassManager::PassManager(const std::string &name, bool run_only_once)
32     : name_(name), passes_{}, run_only_once_(run_only_once), cache_manager_(std::make_shared<CacheManager>()) {}
33 
AddPass(const PassPtr & pass)34 void PassManager::AddPass(const PassPtr &pass) {
35   if (pass != nullptr) {
36     passes_.push_back(pass);
37   }
38 }
39 
RunPass(const FuncGraphPtr & func_graph,size_t pass_id,const PassPtr & pass) const40 bool PassManager::RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const PassPtr &pass) const {
41   MS_LOG(ERROR) << "stub func";
42   return false;
43 }
44 
GetPassFullname(size_t pass_id,const PassPtr & pass) const45 std::string PassManager::GetPassFullname(size_t pass_id, const PassPtr &pass) const {
46   return std::string("hwopt_") + name() + "_" + std::to_string(pass_id) + "_" + pass->name();
47 }
48 
DumpPassIR(const FuncGraphPtr & func_graph,const std::string & pass_fullname) const49 void PassManager::DumpPassIR(const FuncGraphPtr &func_graph, const std::string &pass_fullname) const {
50   MS_LOG(ERROR) << "stub func";
51 }
52 
Run(const FuncGraphPtr & func_graph,const std::vector<PassPtr> & passes) const53 bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr> &passes) const {
54   MS_LOG(ERROR) << "stub func";
55   return false;
56 }
57 
Run(const FuncGraphPtr & func_graph) const58 bool PassManager::Run(const FuncGraphPtr &func_graph) const {
59   MS_LOG(ERROR) << "stub func";
60   return false;
61 }
62 
AddPass(const PassPtr & pass)63 void LitePassManager::AddPass(const PassPtr &pass) {
64   AddPass(pass, {});
65 }
66 
AddPass(const PassPtr & pass,const std::set<std::string> & pass_blacklists)67 void LitePassManager::AddPass(const PassPtr &pass, const std::set<std::string>& pass_blacklists) {
68   if (pass == nullptr) {
69     MS_LOG(ERROR) << "pass is nullptr";
70     return;
71   }
72   if (pass_blacklists.find(pass->name()) != pass_blacklists.end()) {
73     MS_LOG(INFO) << "disable pass: " << pass->name();
74     return;
75   }
76   passes_.push_back(pass);
77 }
78 
RunPass(const FuncGraphPtr & func_graph,size_t pass_id,const PassPtr & pass) const79 bool LitePassManager::RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const PassPtr &pass) const {
80   bool changed = false;
81 #if defined(_WIN32) || defined(_WIN64)
82   auto start_time = std::chrono::steady_clock::now();
83 #else
84   struct timeval start_time {};
85   struct timeval end_time {};
86   (void)gettimeofday(&start_time, nullptr);
87 #endif
88   if (pass->Run(func_graph)) {
89     MS_LOG(DEBUG) << "Run pass and find change";
90     changed = true;
91   }
92 #if defined(_WIN32) || defined(_WIN64)
93   auto end_time = std::chrono::steady_clock::now();
94   std::chrono::duration<double, std::ratio<1, kUSecondInSecond>> cost = end_time - start_time;
95   MS_LOG(INFO) << "Run pass " << GetPassFullname(pass_id, pass) << " in " << cost.count() << " us.";
96 #else
97   (void)gettimeofday(&end_time, nullptr);
98   uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
99   cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
100   MS_LOG(INFO) << "Run pass " << GetPassFullname(pass_id, pass) << " in " << cost << " us.";
101 #endif
102   return changed;
103 }
104 
GetPassFullname(size_t pass_id,const PassPtr & pass) const105 std::string LitePassManager::GetPassFullname(size_t pass_id, const PassPtr &pass) const {
106   return "hwopt_" + name() + "_" + std::to_string(pass_id) + "_" + pass->name();
107 }
108 
Run(const FuncGraphPtr & func_graph,const std::vector<PassPtr> & passes) const109 bool LitePassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr> &passes) const {
110   if (func_graph == nullptr) {
111     return false;
112   }
113   bool changed = false;
114   size_t num = 0;
115   for (const auto &pass : passes) {
116     if (pass != nullptr) {
117       changed = RunPass(func_graph, num, pass) || changed;
118     } else {
119       MS_LOG(INFO) << "pass is null";
120     }
121     num++;
122   }
123   return changed;
124 }
125 
Run(const FuncGraphPtr & func_graph) const126 bool LitePassManager::Run(const FuncGraphPtr &func_graph) const {
127   if (func_graph == nullptr) {
128     return false;
129   }
130   bool changed = false;
131   size_t count = 0;
132   // run all passes
133   bool change = true;
134   while (change) {
135     change = Run(func_graph, passes_);
136     changed = change || changed;
137     if (run_only_once_ || count > kMaxRepassTimes) {
138       break;
139     }
140     count++;
141     MS_LOG(INFO) << "Run pass counts:" << count;
142   }
143   return changed;
144 }
145 }  // namespace opt
146 }  // namespace mindspore
147