• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Python pass register"""
16from inspect import isfunction
17from mindspore.graph_utils.graph_pattern import Pattern, NewParameter
18from mindspore._c_expression import PyPassManager_
19
# Public names exported by ``from <this module> import *``.
__all__ = [
    "register_pass", "unregister_pass",
    "gen_new_parameter", "cancel_new_parameter",
    "set_renorm", "set_reopt",
]
28
29
class PyPassManager(PyPassManager_):
    r"""
    Used to register and unregister python passes which can be used to alter graphs.

    Calling an instance on a pass function registers that function and returns it unchanged,
    so an instance can be used as a decorator.

    Args:
        requires_grad (bool): Do automatic-differentiation after modified graph if true. Default: True.
        run_only_once (bool): Specify whether or not to run pass only once. Default: False.

    Raises:
        TypeError: If argument has invalid type.
    """
    def __init__(self, requires_grad=True, run_only_once=False):
        if not isinstance(requires_grad, bool):
            raise TypeError(f"Expect bool, got : ({type(requires_grad)}){requires_grad}")
        if not isinstance(run_only_once, bool):
            raise TypeError(f"Expect bool, got : ({type(run_only_once)}){run_only_once}")
        self.requires_grad = requires_grad
        self.run_only_once_ = run_only_once
        PyPassManager_.__init__(self)

    def register(self, py_pass):
        """
        Register `py_pass` with the underlying pass manager.

        `py_pass` is called once with no arguments and must return a ``(pattern, target)`` pair of
        `Pattern` objects. The function's ``__name__`` is used as the pass name. The instance's
        `requires_grad` and `run_only_once` settings are forwarded with the pass.

        Args:
            py_pass (function): Pass function returning ``(pattern, target)``.

        Raises:
            TypeError: If `py_pass` is not a function, or either returned object is not a `Pattern`.
        """
        if not isfunction(py_pass):
            raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}")
        pattern, target = py_pass()
        pass_name = py_pass.__name__
        if not isinstance(pattern, Pattern):
            raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
        if not isinstance(target, Pattern):
            raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
        super().register(pass_name, pattern, target, self.requires_grad, self.run_only_once_)

    def unregister(self, py_pass):
        """
        Unregister a pass, given either by its name or by the pass function itself.

        Args:
            py_pass (Union[str, function]): Pass name, or the function whose ``__name__`` was used
                as the pass name at registration.

        Raises:
            TypeError: If `py_pass` is neither a string nor a function.
        """
        if isinstance(py_pass, str):
            super().unregister(py_pass)
            return
        if isfunction(py_pass):
            super().unregister(py_pass.__name__)
            return
        raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")

    def __call__(self, py_pass):
        """Register `py_pass` and return it unchanged, so the instance works as a decorator."""
        self.register(py_pass)
        return py_pass

    def gen_new_parameter(self, pattern):
        """
        Forward a `NewParameter` pattern to the underlying pass manager's `gen_new_parameter`.

        Raises:
            TypeError: If `pattern` is not a `NewParameter`.
        """
        if not isinstance(pattern, NewParameter):
            raise TypeError(f"Expect pattern to be a NewParameter Pattern, got ({type(pattern)}){pattern}")
        super().gen_new_parameter(pattern)

    def set_renorm(self, should_renorm):
        """
        Forward the renormalization switch to the underlying pass manager.

        Raises:
            TypeError: If `should_renorm` is not a bool.
        """
        if not isinstance(should_renorm, bool):
            raise TypeError(f"Expect should_renorm to be a bool, got ({type(should_renorm)}){should_renorm}")
        super().set_renorm(should_renorm)

    def set_reopt(self, do_reopt):
        """
        Forward the re-optimization switch to the underlying pass manager.

        Raises:
            TypeError: If `do_reopt` is not a bool.
        """
        if not isinstance(do_reopt, bool):
            raise TypeError(f"Expect do_reopt to be a bool, got ({type(do_reopt)}){do_reopt}")
        super().set_reopt(do_reopt)
88
89
def register_pass(requires_grad=True, run_only_once=False):
    """
    Register a python pass which would be used in compilation.

    Args:
        requires_grad (bool): Do automatic-differentiation after modified graph if true. Default: True.
        run_only_once (bool): Run this pass only once if set true. Otherwise run the pass until it converges.
            Default: False.

    Returns:
        PyPassManager, intended to be used as a decorator. Applying it to a pass function registers
        that function and returns the decorated pass function unchanged.

    Raises:
        TypeError: If `requires_grad` or `run_only_once` is not a bool.

    Examples:
        >>> from mindspore.graph_utils.graph_pattern import Call, Any
        >>> from mindspore.ops import operations as P
        >>> @register_pass()
        ... def toy_pass():
        ...     x = Any()
        ...     pattern = Call(P.Softmax(), [x])
        ...     target = Call(P.ReLU(), [x])
        ...     return pattern, target
    """
    return PyPassManager(requires_grad, run_only_once)
113
114
def unregister_pass(py_pass):
    """
    Unregister a previously registered python pass.

    Args:
        py_pass (Union[str, function]): The pass to remove. Either its registered name, or the pass
            function itself (its ``__name__`` is used as the name).

    Raises:
        TypeError: If `py_pass` is neither a string nor a function.
    """
    PyPassManager().unregister(py_pass)
124
125
def gen_new_parameter(pattern):
    """
    Generate the specified parameter every time a network gets compiled.

    NOTE:
        With this, every pass that uses this pattern uses the same Parameter. If NewParameter is used
        without gen_new_parameter, every pass match builds a new Parameter.
        This registers a pass that adds the new parameter in the compilation pipeline, so later
        compilations will ALSO add this parameter unless the pass is unregistered. To unregister this
        pass, call cancel_new_parameter(pattern).

    Args:
        pattern (NewParameter): NewParameter type. After gen_new_parameter, it can be used to build
            nested patterns across multiple passes.

    Raises:
        TypeError: If `pattern` is not a NewParameter.

    Examples:
        >>> from mindspore.graph_utils.graph_pattern import NewParameter
        >>> abc = NewParameter("abc")
        >>> gen_new_parameter(abc)
    """
    PyPassManager().gen_new_parameter(pattern)
151
152
def cancel_new_parameter(pattern):
    """
    Use with gen_new_parameter to unregister the pass that gen_new_parameter added.

    Args:
        pattern (NewParameter): NewParameter type. Cancels the pass that would add the new parameter
            this pattern describes.

    Raises:
        TypeError: If `pattern` is not a NewParameter.

    Examples:
        >>> from mindspore.graph_utils.graph_pattern import NewParameter
        >>> abc = NewParameter("abc")
        >>> gen_new_parameter(abc)
        >>> # some compilations
        >>> cancel_new_parameter(abc)
    """
    if not isinstance(pattern, NewParameter):
        raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
    ppm = PyPassManager()
    # The pass is unregistered by the parameter's name (`para_name`).
    ppm.unregister(pattern.para_name)
172
173
def set_renorm(should_renorm):
    """
    Toggle renormalization of graphs that python pass(es) have modified.

    Args:
        should_renorm (bool): True to renormalize after modifying the graph in python pass(es),
            False to skip it.

    Raises:
        TypeError: If `should_renorm` is not a bool.

    NOTE:
        This interface is mainly intended for testing graph modifications without worrying about their
        validity. Turning off renormalization may BREAK the network.
    """
    PyPassManager().set_renorm(should_renorm)
187
188
def set_reopt(do_reopt):
    """
    Set whether or not to do optimization after modified graph in python pass(es).

    Args:
        do_reopt (bool): Whether or not to do optimization after modified graph in python pass(es).

    Raises:
        TypeError: If `do_reopt` is not a bool.

    NOTE:
        This interface is mainly intended for testing graph modifications without worrying about their
        validity. Turning off optimization may BREAK the network.
    """
    ppm = PyPassManager()
    ppm.set_reopt(do_reopt)
202