• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "frontend/optimizer/py_pass_manager.h"
17 
18 #include <functional>
19 #include <utility>
20 
21 #include "ir/manager.h"
22 #include "frontend/optimizer/pass_group.h"
23 
24 namespace mindspore {
25 namespace opt {
26 namespace python_pass {
27 PyPassManagerPtr PyPassManager::global_instance = nullptr;
28 std::unordered_map<Phase, PassGroupPtr> PyPassManager::phase_to_group_;
29 
GetPassGroup(Phase phase)30 PassGroupPtr PyPassManager::GetPassGroup(Phase phase) {
31   auto pm = phase_to_group_.find(phase);
32   if (pm == phase_to_group_.end()) {
33     return nullptr;
34   }
35   return pm->second;
36 }
37 
GetInstance()38 PyPassManagerPtr PyPassManager::GetInstance() {
39   if (global_instance == nullptr) {
40     global_instance = std::shared_ptr<PyPassManager>(new (std::nothrow) PyPassManager());
41   }
42   return global_instance;
43 }
44 
PyPassManager()45 PyPassManager::PyPassManager() {
46   phase_to_group_[Phase::PREAD] = std::make_shared<PassGroup>("Pre_AD_PassGroup");
47   phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>("After_OPT_PassGroup");
48   res_ = std::make_shared<MatchResult>();
49 }
50 
Register(const std::string & pass_name,const PatternPtr & pattern,const PatternPtr & target,bool requires_grad,bool run_only_once)51 void PyPassManager::Register(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
52                              bool requires_grad, bool run_only_once) {
53   PassGroupPtr cur_pg;
54   if (requires_grad) {
55     cur_pg = GetPassGroup(Phase::PREAD);
56   } else {
57     cur_pg = GetPassGroup(Phase::OPT);
58   }
59   MS_EXCEPTION_IF_NULL(cur_pg);
60   cur_pg->SetRunOnlyOnce(run_only_once);
61   MS_EXCEPTION_IF_NULL(pattern);
62   MS_EXCEPTION_IF_NULL(target);
63   MS_EXCEPTION_IF_NULL(cur_pg);
64   PythonPassPtr new_pass = std::make_shared<PythonPass>(pass_name, pattern, target, run_only_once);
65   cur_pg->AddPass(new_pass);
66 }
67 
Unregister(const std::string & pass_name)68 void PyPassManager::Unregister(const std::string &pass_name) {
69   auto opt_pm = GetPassGroup(Phase::OPT);
70   if (!opt_pm->DeletePass(pass_name)) {
71     MS_LOG(WARNING) << "Opt has no such pass : " + pass_name + "\n";
72   }
73   auto pre_ad_pm = GetPassGroup(Phase::PREAD);
74   if (!pre_ad_pm->DeletePass(pass_name)) {
75     MS_LOG(WARNING) << "Pre_AD has no such pass : " + pass_name + "\n";
76   }
77 }
78 
GenNewParameter(const PatternPtr & parameter)79 void PyPassManager::GenNewParameter(const PatternPtr &parameter) {
80   MS_EXCEPTION_IF_NULL(parameter);
81   // NOTE: Add NewParameter at early stage will cause CSE problems
82   auto cur_pg = GetPassGroup(Phase::OPT);
83   MS_EXCEPTION_IF_NULL(cur_pg);
84   cur_pg->SetRunOnlyOnce(true);
85   auto new_para_pattern = parameter->cast<NewParameterPtr>();
86   MS_EXCEPTION_IF_NULL(new_para_pattern);
87   auto pass_name = new_para_pattern->para_name();
88   new_para_pattern->set_last(true);
89   auto new_pass = std::make_shared<PythonPass>(pass_name, nullptr, parameter, true);
90   cur_pg->AddPass(new_pass);
91 }
92 
ClearRes()93 void PyPassManager::ClearRes() {
94   MS_LOG(INFO) << "Clear PyPassManager resources!";
95   global_instance = nullptr;
96   phase_to_group_.clear();
97 }
98 
99 REGISTER_PYBIND_DEFINE(
__anon21ced1ed0102(const py::module *m) 100   PyPassManager_, ([](const py::module *m) {
101     (void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("pre_ad", Phase::PREAD).value("opt", Phase::OPT);
102     (void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_")
103       .def(py::init([]() { return PyPassManager::GetInstance(); }))
104       .def("register", &PyPassManager::Register, "Register python pass")
105       .def("unregister", &PyPassManager::Unregister, "Unregister Python Pass")
106       .def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter")
107       .def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph")
108       .def("set_reopt", &PyPassManager::SetReOpt, "Set whether or not to do optimization after modified graph");
109   }));
110 }  // namespace python_pass
111 }  // namespace opt
112 }  // namespace mindspore
113