# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants
from mindspore.graph_utils.python_pass import register_pass, unregister_pass, set_renorm, gen_new_parameter,\
    cancel_new_parameter, set_reopt
from mindspore.common.api import _generate_pip_args
from mindspore._c_expression import generate_arguments_key, GraphExecutor_
from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm

context.set_context(mode=context.GRAPH_MODE)

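# Helper used by every test below: compile the given Cell with its inputs through
# the graph executor, then return the compiled func graph so each test can inspect
# the transformed graph via its string dump (expanded_str).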
def get_func_graph(obj, *args, phase="validate"):
    args_names, args_list = _generate_pip_args(obj, *args)
    dic = dict(zip(args_names, args_list))
    key = generate_arguments_key(dic)
    obj.arguments_key = str(key)
    phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
    _executor = GraphExecutor_.get_instance()
    _executor.compile(obj, args_list, phase, False, "")
    return _executor.get_func_graph(phase)

def test_softmax_relu():
    """
    Use python pass to transform from Softmax to ReLU.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @register_pass(run_only_once=True)
    def softmax_relu_pass():
        x = Any()
        pattern = Call(P.Softmax(), [x])
        target = Call(P.ReLU(), [x])
        return pattern, target

    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
    unregister_pass(softmax_relu_pass)
    assert "ReLU" in transformed_repr
    assert "Softmax" not in transformed_repr

def test_prim():
    """
    Use a Prim pattern holding several primitives; the Softmax call matches it
    and is replaced with ReLU.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @register_pass(run_only_once=True)
    def softmax_relu_pass():
        x = Any()
        sigmoid_softmax_pattern = Prim([P.Sigmoid(), P.Softmax()])
        pattern = Call(sigmoid_softmax_pattern, [x])
        target = Call(P.ReLU(), [x])
        return pattern, target

    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
    unregister_pass(softmax_relu_pass)
    assert "ReLU" in transformed_repr
    assert "Softmax" not in transformed_repr

def test_softmax_relu_sigmoid():
    """
    Use python pass to transform from Softmax(x) to ReLU(Sigmoid(x)).

    NOTE:
        Sigmoid pattern only exists in the target.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @register_pass(run_only_once=True)
    def softmax_relu_pass():
        x = Any()
        softmax_pattern = Prim(P.Softmax())
        pattern = Call(softmax_pattern, [x])
        sigmoid_pattern = Prim(P.Sigmoid())
        call_sigmoid = Call(sigmoid_pattern, [x])
        relu_pattern = Prim(P.ReLU())
        target = Call(relu_pattern, [call_sigmoid])
        return pattern, target

    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
    unregister_pass(softmax_relu_pass)
    assert "ReLU" in transformed_repr
    assert "Sigmoid" in transformed_repr
    assert "Softmax" not in transformed_repr


def test_isin_pattern_0():
    """
    Test the OneOf pattern, which expresses the IsIn/OneOf semantics.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @register_pass(run_only_once=True)
    def softmax_relu_pass():
        x = Any()
        softmax_pattern = Prim(P.Softmax())
        call_softmax = Call(softmax_pattern, [x])
        relu_pattern = Prim(P.ReLU())
        call_relu = Call(relu_pattern, [x])

        pattern = OneOf([call_softmax, call_relu])
        relu6_pattern = Prim(P.ReLU6())
        target = Call(relu6_pattern, [x])
        return pattern, target
    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
    unregister_pass(softmax_relu_pass)
    assert "ReLU6" in transformed_repr
    assert "Softmax" not in transformed_repr

def test_isin_pattern_1():
    """
    Test OneOf. The OneOf pattern is used as a nested input of the target in this case.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @register_pass(run_only_once=True)
    def softmax_neg_pass():
        x = Any()
        softmax_pattern = Prim(P.Softmax())
        call_softmax = Call(softmax_pattern, [x])
        relu_pattern = Prim(P.ReLU())
        call_relu = Call(relu_pattern, [x])

        pattern = OneOf([call_softmax, call_relu])
        neg_ops = Prim(P.Neg())
        target = Call(neg_ops, [pattern])
        return pattern, target
    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
    unregister_pass(softmax_neg_pass)
    assert "Neg" in transformed_repr
    assert "Softmax" in transformed_repr

def test_isnot_pattern_0():
    """
    Test the NoneOf pattern, which expresses the IsNot/NoneOf semantics.
    Case: the NoneOf pass fails to match.
    """
    set_renorm(False)
    set_reopt(False)
    class ConvBN(nn.Cell):
        def __init__(self):
            super(ConvBN, self).__init__()
            self.conv = P.Conv2D(32, 3)
            self.conv_weight = Tensor(np.ones([32, 32, 3, 3]), mindspore.float32)
            self.scale = Tensor(np.ones([32]), mindspore.float32)
            self.bias = Tensor(np.ones([32]), mindspore.float32)
            self.mean = Tensor(np.ones([32]), mindspore.float32)
            self.variance = Tensor(np.ones([32]), mindspore.float32)
            self.bn = P.BatchNorm()
        def construct(self, x):
            x = self.conv(x, self.conv_weight)
            x = self.bn(x, self.scale, self.bias, self.mean, self.variance)
            return x
    inputs = Tensor(np.random.normal(0, 1, (10, 32, 32, 32)), mindspore.float32)
    conv_bn_model = ConvBN()

    @register_pass(requires_grad=False, run_only_once=True)
    def single_bn_pass():
        """
        Substitute a BatchNorm that does NOT take a Conv2D as input with ReLU6.
        """
        conv2d_prim = Prim("Conv2D")
        conv2d = Call(conv2d_prim)
        pattern_0 = NoneOf(conv2d)
        pattern = Call(P.BatchNorm(), [pattern_0])
        target = Call(P.ReLU6(), [pattern_0])
        return pattern, target

    @register_pass(requires_grad=False, run_only_once=True)
    def bn_pass():
        """
        Substitute a BatchNorm with Softmax.
        """
        pattern = Call(P.BatchNorm())
        target = Call(P.Softmax())
        return pattern, target

    transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5)
    unregister_pass(single_bn_pass)
    unregister_pass(bn_pass)
    # The BatchNorm here does take Conv2D as its input, so single_bn_pass does not
    # match; only bn_pass fires and rewrites BatchNorm to Softmax.
    assert "ReLU6" not in transformed_repr
    assert "Softmax" in transformed_repr
    set_renorm(True)

def test_isnot_pattern_1():
    """
    Test the NoneOf pattern, which expresses the IsNot/NoneOf semantics.
    Case: the NoneOf pattern matches the graph.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @register_pass(run_only_once=True)
    def single_bn_pass():
        """
        Substitute a Softmax that does NOT take a MatMul as input with ReLU6.
        """
        matmul = Prim("MatMul")
        pattern_0 = NoneOf(matmul)
        softmax = P.Softmax()
        pattern = Call(softmax, [pattern_0])
        relu6 = P.ReLU6()
        target = Call(relu6, [pattern_0])
        return pattern, target

    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
    unregister_pass(single_bn_pass)
    assert "ReLU6" in transformed_repr
    assert "Softmax" not in transformed_repr

def test_newtensor_pattern():
    """
    Test NewTensor pattern in the target.
    """
    set_renorm(False)
    set_reopt(False)
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @register_pass(requires_grad=False, run_only_once=True)
    def softmax_addn_pass():
        x = Any()
        pattern = Call(P.Softmax(), [x])

        weight_tensor = Tensor(np.zeros([42]), mindspore.float16)
        new_weight = NewTensor(weight_tensor)
        target = Call(P.AddN(), [x, new_weight])
        return pattern, target
    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
    unregister_pass(softmax_addn_pass)
    assert "AddN" in transformed_repr
    assert "Softmax" not in transformed_repr
    set_renorm(True)

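# Unlike NewTensor above, which embeds a constant tensor into the target, the next
# test builds the target from NewParameter nodes ("Merlin", "Arthur"), each backed
# by a default tensor, and checks that the rewritten graph contains a MatMul over them.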
def test_newparameter_pattern():
    """
    Test NewParameter pattern in the target.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    set_renorm(False)
    set_reopt(False)
    @register_pass(requires_grad=False, run_only_once=True)
    def softmax_addn_pass():
        x = Any()
        pattern = Call(P.Softmax(), [x])

        default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32)
        default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32)
        new_para_0 = NewParameter("Merlin", default_tensor0)
        new_para_1 = NewParameter("Arthur", default_tensor1)
        target_0 = Call(P.MatMul(), [new_para_0, new_para_1])
        target = Call("MakeTuple", [target_0])
        return pattern, target
    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
    unregister_pass(softmax_addn_pass)
    assert "MatMul" in transformed_repr
    assert "MakeTuple" in transformed_repr
    assert "Softmax" not in transformed_repr

def test_imm_target():
    """
    Test Imm pattern in the target.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    set_renorm(False)
    set_reopt(False)
    @register_pass(requires_grad=False, run_only_once=True)
    def softmax_pass():
        x = Any()
        pattern = Call(P.Softmax(), [x])
        imm = Imm(0)
        target_0 = Call("MakeTuple", [pattern])
        target = Call(Constants.kTupleGetItem, [target_0, imm])
        return pattern, target
    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
    unregister_pass(softmax_pass)
    assert "MakeTuple" in transformed_repr
    assert Constants.kTupleGetItem in transformed_repr
    assert "Softmax" in transformed_repr

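# gen_new_parameter registers the NewParameter ("Merlin") ahead of time so it appears
# in the transformed graph while the pass below is active; cancel_new_parameter reverts
# this, which is what the two assertions on "Merlin" verify.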
def test_gen_new_parameter():
    """
    Test gen_new_parameter.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
    new_para = NewParameter("Merlin", default_tensor)
    set_renorm(False)
    set_reopt(False)
    gen_new_parameter(new_para)
    @register_pass(requires_grad=False, run_only_once=True)
    def softmax_make_tuple_pass():
        x = Any()
        softmax = P.Softmax()
        pattern = Call(softmax, [x])

        target = Call("MakeTuple", [pattern, new_para])
        return pattern, target
    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
    assert "Merlin" in transformed_repr
    unregister_pass(softmax_make_tuple_pass)
    cancel_new_parameter(new_para)
    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
    assert "Merlin" not in transformed_repr