# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Tests for python passes: each test registers a pattern -> target rewrite, compiles a
small network in graph mode and inspects the textual form of the compiled graph.
"""
from contextlib import contextmanager

import numpy as np

import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants
from mindspore.graph_utils.python_pass import register_pass, unregister_pass, set_renorm, gen_new_parameter,\
    cancel_new_parameter, set_reopt
from mindspore.common.api import _generate_pip_args
from mindspore._c_expression import generate_arguments_key, GraphExecutor_
from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm

context.set_context(mode=context.GRAPH_MODE)


def get_func_graph(obj, *args, phase="validate"):
    """
    Compile `obj` with `args` and return the func graph the executor produced.

    Args:
        obj: the Cell to compile.
        *args: inputs passed to the Cell.
        phase: prefix of the compile phase; the Cell's create_time, id and arguments key
            are appended to it to build the full phase name.

    Returns:
        The func graph the executor holds for the compiled phase.

    Side effect: sets `obj.arguments_key`.
    """
    args_names, args_list = _generate_pip_args(obj, *args)
    key = generate_arguments_key(dict(zip(args_names, args_list)))
    obj.arguments_key = str(key)
    full_phase = f"{phase}.{obj.create_time}.{id(obj)}.{obj.arguments_key}"
    executor = GraphExecutor_.get_instance()
    executor.compile(obj, args_list, full_phase, False, "")
    return executor.get_func_graph(full_phase)


def _compiled_repr(model, inputs, depth, *passes):
    """
    Compile `model` on `inputs` and return the expanded string (to `depth`) of the return node.

    Every pass in `passes` is unregistered afterwards, even if compilation raises, so a
    failing test cannot leave its pass registered for the tests that follow.
    """
    try:
        return get_func_graph(model, inputs).get_return().expanded_str(depth)
    finally:
        for python_pass in passes:
            unregister_pass(python_pass)


@contextmanager
def _renorm_and_reopt_disabled():
    """
    Temporarily turn off renormalization and re-optimization after python passes.

    Both switches are turned back on when the block exits, even on failure, so this global
    pass-manager state does not leak into later tests.
    NOTE(review): assumes True is the default for both switches (the original tests restored
    set_renorm(True)); confirm against PyPassManager.
    """
    set_renorm(False)
    set_reopt(False)
    try:
        yield
    finally:
        set_renorm(True)
        set_reopt(True)


def test_softmax_relu():
    """
    Use python pass to transform from Softmax to ReLU.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @register_pass(run_only_once=True)
    def softmax_relu_pass():
        x = Any()
        pattern = Call(P.Softmax(), [x])
        target = Call(P.ReLU(), [x])
        return pattern, target

    transformed_repr = _compiled_repr(softmax_model, inputs, 2, softmax_relu_pass)
    assert "ReLU" in transformed_repr
    assert "Softmax" not in transformed_repr


def test_prim():
    """
    Use a Prim pattern built from a list of primitives (Sigmoid, Softmax) to transform the
    Softmax in the model to ReLU.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @register_pass(run_only_once=True)
    def softmax_relu_pass():
        x = Any()
        sigmoid_softmax_pattern = Prim([P.Sigmoid(), P.Softmax()])
        pattern = Call(sigmoid_softmax_pattern, [x])
        target = Call(P.ReLU(), [x])
        return pattern, target

    transformed_repr = _compiled_repr(softmax_model, inputs, 3, softmax_relu_pass)
    assert "ReLU" in transformed_repr
    assert "Softmax" not in transformed_repr


def test_softmax_relu_sigmoid():
    """
    Use python pass to transform from Softmax(x) to ReLU(Sigmoid(x)).

    NOTE:
        Sigmoid pattern only exists in the target.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @register_pass(run_only_once=True)
    def softmax_relu_pass():
        x = Any()
        softmax_pattern = Prim(P.Softmax())
        pattern = Call(softmax_pattern, [x])
        sigmoid_pattern = Prim(P.Sigmoid())
        call_sigmoid = Call(sigmoid_pattern, [x])
        relu_pattern = Prim(P.ReLU())
        target = Call(relu_pattern, [call_sigmoid])
        return pattern, target

    transformed_repr = _compiled_repr(softmax_model, inputs, 3, softmax_relu_pass)
    assert "ReLU" in transformed_repr
    assert "Sigmoid" in transformed_repr
    assert "Softmax" not in transformed_repr


def test_isin_pattern_0():
    """
    Test IsIn pattern which expresses the IsIn/OneOf semantics: a Softmax or ReLU call on
    the same input is rewritten to ReLU6.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @register_pass(run_only_once=True)
    def softmax_relu_pass():
        x = Any()
        softmax_pattern = Prim(P.Softmax())
        call_softmax = Call(softmax_pattern, [x])
        relu_pattern = Prim(P.ReLU())
        call_relu = Call(relu_pattern, [x])

        pattern = OneOf([call_softmax, call_relu])
        relu6_pattern = Prim(P.ReLU6())
        target = Call(relu6_pattern, [x])
        return pattern, target

    transformed_repr = _compiled_repr(softmax_model, inputs, 2, softmax_relu_pass)
    assert "ReLU6" in transformed_repr
    assert "Softmax" not in transformed_repr


def test_isin_pattern_1():
    """
    Test IsIn. IsIn is used as nested inputs for the target in this case: the matched
    call is kept and wrapped in Neg, so Softmax is still expected in the graph.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @register_pass(run_only_once=True)
    def softmax_neg_pass():
        x = Any()
        softmax_pattern = Prim(P.Softmax())
        call_softmax = Call(softmax_pattern, [x])
        relu_pattern = Prim(P.ReLU())
        call_relu = Call(relu_pattern, [x])

        pattern = OneOf([call_softmax, call_relu])
        neg_ops = Prim(P.Neg())
        target = Call(neg_ops, [pattern])
        return pattern, target

    transformed_repr = _compiled_repr(softmax_model, inputs, 4, softmax_neg_pass)
    assert "Neg" in transformed_repr
    assert "Softmax" in transformed_repr


def test_isnot_pattern_0():
    """
    Test IsNot pattern which expresses the IsNot semantics.
    Case: IsNot pass failed to match, so only the unconditional BN -> Softmax pass applies.
    """
    class ConvBN(nn.Cell):
        def __init__(self):
            super().__init__()
            self.conv = P.Conv2D(32, 3)
            self.conv_weight = Tensor(np.ones([32, 32, 3, 3]), mindspore.float32)
            self.scale = Tensor(np.ones([32]), mindspore.float32)
            self.bias = Tensor(np.ones([32]), mindspore.float32)
            self.mean = Tensor(np.ones([32]), mindspore.float32)
            self.variance = Tensor(np.ones([32]), mindspore.float32)
            self.bn = P.BatchNorm()

        def construct(self, x):
            x = self.conv(x, self.conv_weight)
            x = self.bn(x, self.scale, self.bias, self.mean, self.variance)
            return x

    inputs = Tensor(np.random.normal(0, 1, (10, 32, 32, 32)), mindspore.float32)
    conv_bn_model = ConvBN()

    with _renorm_and_reopt_disabled():
        @register_pass(requires_grad=False, run_only_once=True)
        def single_bn_pass():
            """
            Sub a BN which does NOT take Conv as inputs to ReLU6.
            """
            conv2d_prim = Prim("Conv2D")
            conv2d = Call(conv2d_prim)
            pattern_0 = NoneOf(conv2d)
            pattern = Call(P.BatchNorm(), [pattern_0])
            target = Call(P.ReLU6(), [pattern_0])
            return pattern, target

        @register_pass(requires_grad=False, run_only_once=True)
        def bn_pass():
            """
            Sub a BN to Softmax.
            """
            pattern = Call(P.BatchNorm())
            target = Call(P.Softmax())
            return pattern, target

        transformed_repr = _compiled_repr(conv_bn_model, inputs, 5, single_bn_pass, bn_pass)
    assert "ReLU6" not in transformed_repr
    assert "Softmax" in transformed_repr


def test_isnot_pattern_1():
    """
    Test IsNot pattern which expresses the IsNot semantics.
    Case: IsNot pattern matches with the graph
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @register_pass(run_only_once=True)
    def single_bn_pass():
        """
        Sub a Softmax which does NOT take MatMul as inputs to ReLU6.
        """
        matmul = Prim("MatMul")
        pattern_0 = NoneOf(matmul)
        softmax = P.Softmax()
        pattern = Call(softmax, [pattern_0])
        relu6 = P.ReLU6()
        target = Call(relu6, [pattern_0])
        return pattern, target

    transformed_repr = _compiled_repr(softmax_model, inputs, 5, single_bn_pass)
    assert "ReLU6" in transformed_repr
    assert "Softmax" not in transformed_repr


def test_newtensor_pattern():
    """
    Test NewTensor pattern in the target: Softmax(x) becomes AddN(x, <new zero tensor>).
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    with _renorm_and_reopt_disabled():
        @register_pass(requires_grad=False, run_only_once=True)
        def softmax_addn_pass():
            x = Any()
            pattern = Call(P.Softmax(), [x])

            weight_tensor = Tensor(np.zeros([42]), mindspore.float16)
            new_weight = NewTensor(weight_tensor)
            target = Call(P.AddN(), [x, new_weight])
            return pattern, target

        transformed_repr = _compiled_repr(softmax_model, inputs, 2, softmax_addn_pass)
    assert "AddN" in transformed_repr
    assert "Softmax" not in transformed_repr


def test_newparameter_pattern():
    """
    Test NewParameter pattern in the target: Softmax is replaced by
    MakeTuple(MatMul(<new param>, <new param>)).
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    with _renorm_and_reopt_disabled():
        @register_pass(requires_grad=False, run_only_once=True)
        def softmax_addn_pass():
            x = Any()
            pattern = Call(P.Softmax(), [x])

            default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32)
            default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32)
            new_para_0 = NewParameter("Merlin", default_tensor0)
            new_para_1 = NewParameter("Arthur", default_tensor1)
            target_0 = Call(P.MatMul(), [new_para_0, new_para_1])
            target = Call("MakeTuple", [target_0])
            return pattern, target

        transformed_repr = _compiled_repr(softmax_model, inputs, 5, softmax_addn_pass)
    assert "MatMul" in transformed_repr
    assert "MakeTuple" in transformed_repr
    assert "Softmax" not in transformed_repr


def test_imm_target():
    """
    Test Imm pattern in the target: the matched Softmax call is wrapped in MakeTuple and
    taken back out with TupleGetItem at immediate index 0, so Softmax remains in the graph.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    with _renorm_and_reopt_disabled():
        @register_pass(requires_grad=False, run_only_once=True)
        def softmax_pass():
            x = Any()
            pattern = Call(P.Softmax(), [x])
            imm = Imm(0)
            target_0 = Call("MakeTuple", [pattern])
            target = Call(Constants.kTupleGetItem, [target_0, imm])
            return pattern, target

        transformed_repr = _compiled_repr(softmax_model, inputs, 5, softmax_pass)
    assert "MakeTuple" in transformed_repr
    assert Constants.kTupleGetItem in transformed_repr
    assert "Softmax" in transformed_repr


def test_gen_new_parameter():
    """
    Test gen_new_parameter: a parameter generated up front can be used by a pass target,
    and is gone from the compiled graph once cancel_new_parameter has been called.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
    new_para = NewParameter("Merlin", default_tensor)
    with _renorm_and_reopt_disabled():
        gen_new_parameter(new_para)
        # cancel_new_parameter must run even if the first compile or assertion fails,
        # otherwise "Merlin" leaks into every later compilation.
        try:
            @register_pass(requires_grad=False, run_only_once=True)
            def softmax_make_tuple_pass():
                x = Any()
                softmax = P.Softmax()
                pattern = Call(softmax, [x])

                target = Call("MakeTuple", [pattern, new_para])
                return pattern, target

            transformed_repr = _compiled_repr(softmax_model, inputs, 5, softmax_make_tuple_pass)
            assert "Merlin" in transformed_repr
        finally:
            cancel_new_parameter(new_para)
        transformed_repr = _compiled_repr(softmax_model, inputs, 5)
    assert "Merlin" not in transformed_repr