1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "frontend/optimizer/py_pass_manager.h"
17
18 #include <functional>
19 #include <utility>
20
21 #include "ir/manager.h"
22 #include "frontend/optimizer/pass_group.h"
23
24 namespace mindspore {
25 namespace opt {
26 namespace python_pass {
27 PyPassManagerPtr PyPassManager::global_instance = nullptr;
28 std::unordered_map<Phase, PassGroupPtr> PyPassManager::phase_to_group_;
29
GetPassGroup(Phase phase)30 PassGroupPtr PyPassManager::GetPassGroup(Phase phase) {
31 auto pm = phase_to_group_.find(phase);
32 if (pm == phase_to_group_.end()) {
33 return nullptr;
34 }
35 return pm->second;
36 }
37
GetInstance()38 PyPassManagerPtr PyPassManager::GetInstance() {
39 if (global_instance == nullptr) {
40 global_instance = std::shared_ptr<PyPassManager>(new (std::nothrow) PyPassManager());
41 }
42 return global_instance;
43 }
44
PyPassManager()45 PyPassManager::PyPassManager() {
46 phase_to_group_[Phase::PREAD] = std::make_shared<PassGroup>("Pre_AD_PassGroup");
47 phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>("After_OPT_PassGroup");
48 res_ = std::make_shared<MatchResult>();
49 }
50
Register(const std::string & pass_name,const PatternPtr & pattern,const PatternPtr & target,bool requires_grad,bool run_only_once)51 void PyPassManager::Register(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
52 bool requires_grad, bool run_only_once) {
53 PassGroupPtr cur_pg;
54 if (requires_grad) {
55 cur_pg = GetPassGroup(Phase::PREAD);
56 } else {
57 cur_pg = GetPassGroup(Phase::OPT);
58 }
59 MS_EXCEPTION_IF_NULL(cur_pg);
60 cur_pg->SetRunOnlyOnce(run_only_once);
61 MS_EXCEPTION_IF_NULL(pattern);
62 MS_EXCEPTION_IF_NULL(target);
63 MS_EXCEPTION_IF_NULL(cur_pg);
64 PythonPassPtr new_pass = std::make_shared<PythonPass>(pass_name, pattern, target, run_only_once);
65 cur_pg->AddPass(new_pass);
66 }
67
Unregister(const std::string & pass_name)68 void PyPassManager::Unregister(const std::string &pass_name) {
69 auto opt_pm = GetPassGroup(Phase::OPT);
70 if (!opt_pm->DeletePass(pass_name)) {
71 MS_LOG(WARNING) << "Opt has no such pass : " + pass_name + "\n";
72 }
73 auto pre_ad_pm = GetPassGroup(Phase::PREAD);
74 if (!pre_ad_pm->DeletePass(pass_name)) {
75 MS_LOG(WARNING) << "Pre_AD has no such pass : " + pass_name + "\n";
76 }
77 }
78
GenNewParameter(const PatternPtr & parameter)79 void PyPassManager::GenNewParameter(const PatternPtr ¶meter) {
80 MS_EXCEPTION_IF_NULL(parameter);
81 // NOTE: Add NewParameter at early stage will cause CSE problems
82 auto cur_pg = GetPassGroup(Phase::OPT);
83 MS_EXCEPTION_IF_NULL(cur_pg);
84 cur_pg->SetRunOnlyOnce(true);
85 auto new_para_pattern = parameter->cast<NewParameterPtr>();
86 MS_EXCEPTION_IF_NULL(new_para_pattern);
87 auto pass_name = new_para_pattern->para_name();
88 new_para_pattern->set_last(true);
89 auto new_pass = std::make_shared<PythonPass>(pass_name, nullptr, parameter, true);
90 cur_pg->AddPass(new_pass);
91 }
92
ClearRes()93 void PyPassManager::ClearRes() {
94 MS_LOG(INFO) << "Clear PyPassManager resources!";
95 global_instance = nullptr;
96 phase_to_group_.clear();
97 }
98
99 REGISTER_PYBIND_DEFINE(
__anon21ced1ed0102(const py::module *m) 100 PyPassManager_, ([](const py::module *m) {
101 (void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("pre_ad", Phase::PREAD).value("opt", Phase::OPT);
102 (void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_")
103 .def(py::init([]() { return PyPassManager::GetInstance(); }))
104 .def("register", &PyPassManager::Register, "Register python pass")
105 .def("unregister", &PyPassManager::Unregister, "Unregister Python Pass")
106 .def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter")
107 .def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph")
108 .def("set_reopt", &PyPassManager::SetReOpt, "Set whether or not to do optimization after modified graph");
109 }));
110 } // namespace python_pass
111 } // namespace opt
112 } // namespace mindspore
113