# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Python pass register"""
from inspect import isfunction

from mindspore.graph_utils.graph_pattern import Pattern, NewParameter
from mindspore._c_expression import PyPassManager_

__all__ = [
    "register_pass",
    "unregister_pass",
    "gen_new_parameter",
    "cancel_new_parameter",
    "set_renorm",
    "set_reopt"
]


class PyPassManager(PyPassManager_):
    r"""
    Used to register and unregister python passes which can be used to alter graphs.

    Instances are callable: calling an instance with a pass function registers that pass and returns the
    function unchanged, so an instance can be used directly as a decorator.

    Args:
        requires_grad (bool): Do automatic-differentiation after modified graph if true. Default: True.
        run_only_once (bool): Specify whether or not to run pass only once. Default: False.

    Raises:
        TypeError: If argument has invalid type.
    """

    def __init__(self, requires_grad=True, run_only_once=False):
        if not isinstance(requires_grad, bool):
            raise TypeError(f"Expect bool, got : ({type(requires_grad)}){requires_grad}")
        if not isinstance(run_only_once, bool):
            raise TypeError(f"Expect bool, got : ({type(run_only_once)}){run_only_once}")
        self.requires_grad = requires_grad
        # Trailing underscore kept as-is; the value is forwarded to the native register() call below.
        self.run_only_once_ = run_only_once
        PyPassManager_.__init__(self)

    def register(self, py_pass):
        """
        Register `py_pass` with the underlying native pass manager.

        `py_pass` is called once here, with no arguments, to obtain its `(pattern, target)` pair. The
        function's `__name__` is used as the pass name, which is also the key `unregister` accepts.

        Args:
            py_pass (function): A function returning a `(pattern, target)` pair of `Pattern` objects.

        Raises:
            TypeError: If `py_pass` is not a function, or if the returned pattern or target is not a `Pattern`.
        """
        if not isfunction(py_pass):
            raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}")
        pattern, target = py_pass()
        pass_name = py_pass.__name__
        if not isinstance(pattern, Pattern):
            raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
        if not isinstance(target, Pattern):
            raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
        super().register(pass_name, pattern, target, self.requires_grad, self.run_only_once_)

    def unregister(self, py_pass):
        """
        Unregister a previously registered pass.

        Args:
            py_pass (Union[str, function]): The pass name, or the pass function itself (its `__name__` is used).

        Raises:
            TypeError: If `py_pass` is neither a string nor a function.
        """
        if isinstance(py_pass, str):
            super().unregister(py_pass)
            return
        if isfunction(py_pass):
            super().unregister(py_pass.__name__)
            return
        raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")

    def __call__(self, py_pass):
        """Register `py_pass` and return it unchanged, so this instance can be used as a decorator."""
        self.register(py_pass)
        return py_pass

    def gen_new_parameter(self, pattern):
        """
        Forward `pattern` to the native `gen_new_parameter` after validating its type.

        Raises:
            TypeError: If `pattern` is not a `NewParameter`.
        """
        if not isinstance(pattern, NewParameter):
            raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
        super().gen_new_parameter(pattern)

    def set_renorm(self, should_renorm):
        """
        Forward the renormalization switch to the native pass manager.

        Raises:
            TypeError: If `should_renorm` is not a bool.
        """
        if not isinstance(should_renorm, bool):
            raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}")
        super().set_renorm(should_renorm)

    def set_reopt(self, do_reopt):
        """
        Forward the re-optimization switch to the native pass manager.

        Raises:
            TypeError: If `do_reopt` is not a bool.
        """
        if not isinstance(do_reopt, bool):
            raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}")
        super().set_reopt(do_reopt)


def register_pass(requires_grad=True, run_only_once=False):
    """
    Register python pass to specified pipeline phase which would be used in compilation.

    Args:
        requires_grad (bool): Do automatic-differentiation after modified graph if true. Default: True.
        run_only_once (bool): Run this pass only once if set true. Otherwise run the pass until converge. Default:
            False.

    Returns:
        PyPassManager, intended to be used as a decorator. Applying it to a pass function registers the pass
        and returns the decorated pass function unchanged.

    Raises:
        TypeError: If `requires_grad` or `run_only_once` is not a bool.

    Examples:
        >>> from mindspore.graph_utils.graph_pattern import Call, Any
        >>> from mindspore.ops import operations as P
        >>> @register_pass()
        ... def toy_pass():
        ...     x = Any()
        ...     pattern = Call(P.Softmax(), [x])
        ...     target = Call(P.ReLU(), [x])
        ...     return pattern, target
    """
    return PyPassManager(requires_grad, run_only_once)


def unregister_pass(py_pass):
    """
    Unregister python pass.

    Args:
        py_pass (Union[str, function]): Target python pass to unregister, given either by its name or by the
            pass function itself.

    Raises:
        TypeError: If `py_pass` is neither a string nor a function.
    """
    ppm = PyPassManager()
    ppm.unregister(py_pass)


def gen_new_parameter(pattern):
    """
    Generate specified parameter every time a network gets compiled.

    NOTE:
        In this way, every pass that uses this pattern would be using the same Parameter. If NewParameter is used
        without gen_new_parameter, every pass match would build a new Parameter.
        This would register a pass to add new parameter in the compilation pipeline, so later compilation would
        ALSO add this parameter unless the pass is unregistered. To unregister this pass, call
        cancel_new_parameter(pattern).

    Args:
        pattern (NewParameter): NewParameter type, could be used to build nested patterns across multiple passes
            after gen_new_parameter.

    Raises:
        TypeError: If argument has invalid type.

    Examples:
        >>> from mindspore.graph_utils.graph_pattern import NewParameter
        >>> abc = NewParameter("abc")
        >>> gen_new_parameter(abc)
    """
    ppm = PyPassManager()
    ppm.gen_new_parameter(pattern)


def cancel_new_parameter(pattern):
    """
    Use with gen_new_parameter to unregister gen_new_parameter pass.

    The pass is looked up and unregistered by the pattern's `para_name`.

    Args:
        pattern (NewParameter): NewParameter type, cancel the pass which would add new parameter as this pattern
            describes.

    Raises:
        TypeError: If `pattern` is not a NewParameter.

    Examples:
        >>> from mindspore.graph_utils.graph_pattern import NewParameter
        >>> abc = NewParameter("abc")
        >>> gen_new_parameter(abc)
        >>> # some compilations
        >>> cancel_new_parameter(abc)
    """
    if not isinstance(pattern, NewParameter):
        raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
    ppm = PyPassManager()
    ppm.unregister(pattern.para_name)


def set_renorm(should_renorm):
    """
    Set whether or not to do renormalization after modified graph in python pass(es).

    Args:
        should_renorm (bool): Whether or not to do renormalization after modified graph in python pass(es).

    Raises:
        TypeError: If `should_renorm` is not a bool.

    NOTE:
        This interface is mainly intended for testing modifying graph without worrying about its validity. Turning
        off renormalization may BREAK the network.
    """
    ppm = PyPassManager()
    ppm.set_renorm(should_renorm)


def set_reopt(do_reopt):
    """
    Set whether or not to do optimization after modified graph in python pass(es).

    Args:
        do_reopt (bool): Whether or not to do optimization after modified graph in python pass(es).

    Raises:
        TypeError: If `do_reopt` is not a bool.

    NOTE:
        This interface is mainly intended for testing modifying graph without worrying about its validity. Turning
        off optimization after graph modification may BREAK the network.
    """
    ppm = PyPassManager()
    ppm.set_reopt(do_reopt)