# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_pynative_hook_grad """
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore.nn import Cell
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.ops.composite import GradOperation
from mindspore.common import ParameterTuple

class MetaFactory:
    def __init__(self):
        self.device_target = context.get_context('device_target')
        self.rank_size = None
        self.device_id = None
        self.global_rank_id = None

class HookBase(MetaFactory):
    def __init__(self):
        super().__init__()
        self.grad_input_list = []
        self.grad_output_list = []

    def ms_record_hook(self, cell_id, grad_input, grad_output):
        # Record every backward gradient flowing into and out of the cell.
        for grad in grad_input:
            self.grad_input_list.append(grad)
        for grad in grad_output:
            self.grad_output_list.append(grad)

    def ms_change_grad_double_hook(self, cell_id, grad_input, grad_output):
        # Replace the outgoing gradient with twice its value.
        y = Tensor(np.array([2.0]).astype(np.float32))
        mul = P.Mul()
        grad = grad_output[0]
        output = mul(grad, y)
        return output

class FinalNet(nn.Cell, HookBase):
    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.conv = nn.Conv2d(1, 3, 3)
        self.relu = nn.ReLU()

    def construct(self, x, flag):
        if flag:
            x = self.conv(x)
        else:
            x = self.relu(x)
        return self.relu(x)

class _Grad(Cell):
    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
        super().__init__()
        self.network = network
        self.grad = grad
        self.sens_param = self.grad.sens_param
        self.wrt_params = wrt_params
        self.real_inputs_count = real_inputs_count
        if self.wrt_params:
            self.params = ParameterTuple(self.network.trainable_params())

    def construct(self, *inputs):
        if self.wrt_params:
            if self.real_inputs_count is None or self.sens_param is False:
                return self.grad(self.network, self.params)(*inputs)
            real_inputs = inputs[:self.real_inputs_count]
            sens_param_inputs = inputs[self.real_inputs_count:]
            return self.grad(self.network, self.params)(*real_inputs, sens_param_inputs)
        if self.real_inputs_count is None or self.sens_param is False:
            return self.grad(self.network)(*inputs)
        real_inputs = inputs[:self.real_inputs_count]
        sens_param_inputs = inputs[self.real_inputs_count:]
        return self.grad(self.network)(*real_inputs, sens_param_inputs)

class GradOfAllInputs(_Grad):
    def __init__(self, network, sens_param=True, real_inputs_count=None):
        super().__init__(grad=GradOperation(get_all=True, sens_param=sens_param),
                         network=network, real_inputs_count=real_inputs_count)
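
# A minimal usage sketch for the wrapper above (illustrative, not executed at
# import time): with sens_param=True, the trailing positional argument of the
# grad net is taken as the gradient sensitivity (the initial dout), so
#
#     grad_net = GradOfAllInputs(net)
#     grad_net.set_train()
#     grads = grad_net(x, sens)   # tuple of d(out)/d(input), scaled by sens
#
# The test helpers below follow this pattern and pass the forward output
# itself as the sensitivity.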

class MsMul4(nn.Cell):
    def construct(self, input_mul):
        out = input_mul * 2
        return out

class MsMul(nn.Cell):
    def __init__(self):
        super().__init__()
        self.mul = P.Mul()

    def construct(self, x, y):
        x = self.mul(x, y)
        return x

class MsAdd4(nn.Cell):
    def construct(self, input_add):
        out = input_add + 4
        return out

class MsOneInputNet(nn.Cell, HookBase):
    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.add = MsAdd4()
        self.mul = MsMul4()
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.add(x)
        x = self.mul(x)
        out = self.relu(x)
        return out

class MsMultiInputNet(nn.Cell, HookBase):
    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.mul1 = MsMul()
        self.mul2 = MsMul4()

    def construct(self, x, y):
        a = self.mul1(x, y)
        b = self.mul2(x)
        output = self.mul1(a, b)
        return output

class MsNetWithParameter(nn.Cell, HookBase):
    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.conv1 = nn.Conv2d(2, 4, kernel_size=(1, 1), has_bias=True,
                               weight_init=Tensor(np.ones([4, 2, 1, 1]).astype(np.float32)),
                               bias_init=Tensor(np.ones([4]).astype(np.float32)))
        self.conv2 = nn.Conv2d(4, 8, kernel_size=(1, 1), has_bias=True,
                               weight_init=Tensor(np.ones([8, 4, 1, 1]).astype(np.float32)),
                               bias_init=Tensor(np.ones([8]).astype(np.float32)))

    def construct(self, x):
        x = self.conv1(x)
        output = self.conv2(x)
        return output

class MsNetWithCellinCell(nn.Cell, HookBase):
    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.net1 = MsOneInputNet()
        self.mul = MsMul4()

    def construct(self, x):
        x = self.net1(x)
        output = self.mul(x)
        return output

class MsSingleOpNetWithBprop(nn.Cell, HookBase):
    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.op = nn.ReLU()

    def construct(self, x):
        return self.op(x)

    def bprop(self, x, out, dout):
        # Custom backward: ignore dout and return x * 5.
        y = Tensor(np.array([5.0]).astype(np.float32))
        mul = P.Mul()
        return mul(x, y)

class MsNetHasBpropInChild(nn.Cell, HookBase):
    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.add = MsAdd4()
        self.bprop_net = MsSingleOpNetWithBprop()

    def construct(self, x):
        x = self.add(x)
        return self.bprop_net(x)

class MsMultiOpNetWithBprop(nn.Cell, HookBase):
    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.mul = MsMul4()
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.mul(x)
        return self.relu(x)

    def bprop(self, x, out, dout):
        # Custom backward: ignore dout and return x * 5.
        y = Tensor(np.array([5.0]).astype(np.float32))
        mul = P.Mul()
        return mul(x, y)

def _count_unequal_element(data_expected, data_me, rtol, atol):
    assert data_expected.shape == data_me.shape
    total_count = len(data_expected.flatten())
    error = np.abs(data_expected - data_me)
    greater = np.greater(error, atol + np.abs(data_me)*rtol)
    loss_count = np.count_nonzero(greater)
    assert (loss_count/total_count) < rtol,\
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".\
        format(data_expected[greater], data_me[greater], error[greater])

def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
    if np.any(np.isnan(data_expected)):
        assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
    elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
        _count_unequal_element(data_expected, data_me, rtol, atol)
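
# A worked instance of the tolerance rule above (illustrative numbers, not a
# test case): with rtol = atol = 0.001, the pair (expected=20.0, actual=20.01)
# has error 0.01, within the per-element bound 0.001 + 20.01 * 0.001 ~ 0.021,
# so it passes; _count_unequal_element only raises once the fraction of
# out-of-bound elements reaches rtol.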
def pynative_hook_diff_hook():
    # Smoke test: a recording hook and a grad-doubling hook coexist on
    # different sub-cells and the backward pass runs without error.
    input_np = np.ones([1, 1, 224, 224]).astype(np.float32)
    ms_net = FinalNet()
    ms_net.set_grad()
    ms_net.conv.register_backward_hook(ms_net.ms_record_hook)
    ms_net.relu.register_backward_hook(ms_net.ms_change_grad_double_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms, Tensor(1))
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    grad_net(input_ms, Tensor(1), out_ms)

def pynative_hook_outermost_cell_not_change_grad():
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsOneInputNet()
    ms_net.set_grad()
    ms_net.register_backward_hook(ms_net.ms_record_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    input_ms_grad = grad_net(input_ms, out_ms)

    # input grad
    input_torch_grad = np.array([[20, 20], [20, 20]])
    allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
    # hook record grad
    torch_net_grad_output = np.array([[10, 10], [10, 10]])
    torch_net_grad_input = np.array([[20, 20], [20, 20]])
    allclose_nparray(torch_net_grad_output, ms_net.grad_input_list[0].asnumpy(), 0.001, 0.001)
    allclose_nparray(torch_net_grad_input, ms_net.grad_output_list[0].asnumpy(), 0.001, 0.001)

def pynative_hook_all_cell_record_grad():
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsOneInputNet()
    ms_net.set_grad()
    ms_net.mul.register_backward_hook(ms_net.ms_record_hook)
    ms_net.add.register_backward_hook(ms_net.ms_record_hook)
    ms_net.relu.register_backward_hook(ms_net.ms_record_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    grad_net(input_ms, out_ms)

    torch_net_grad_input0 = np.array([[10, 10], [10, 10]])
    torch_net_grad_output0 = np.array([[10, 10], [10, 10]])
    torch_net_grad_input1 = np.array([[20, 20], [20, 20]])
    torch_net_grad_output1 = np.array([[10, 10], [10, 10]])
    allclose_nparray(torch_net_grad_input0, ms_net.grad_output_list[0].asnumpy(), 0.001, 0.001)
    allclose_nparray(torch_net_grad_output0, ms_net.grad_input_list[0].asnumpy(), 0.001, 0.001)
    allclose_nparray(torch_net_grad_input1, ms_net.grad_output_list[1].asnumpy(), 0.001, 0.001)
    allclose_nparray(torch_net_grad_output1, ms_net.grad_input_list[1].asnumpy(), 0.001, 0.001)

    torch_net_grad_input3 = np.array([[20, 20], [20, 20]])
    torch_net_grad_output2 = np.array([[20, 20], [20, 20]])
    allclose_nparray(torch_net_grad_input3, ms_net.grad_output_list[2].asnumpy(), 0.001, 0.001)
    allclose_nparray(torch_net_grad_output2, ms_net.grad_input_list[2].asnumpy(), 0.001, 0.001)

def pynative_hook_mul_change_input_grad():
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsOneInputNet()
    ms_net.set_grad()
    ms_net.mul.register_backward_hook(ms_net.ms_change_grad_double_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    input_ms_grad = grad_net(input_ms, out_ms)

    # input grad
    input_torch_grad = np.array([[40, 40], [40, 40]])
    allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
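
# The reference values above can be derived by hand. For MsOneInputNet with
# x = 1: out = relu((x + 4) * 2) = 10 and d(out)/dx = 2. Since the grad net is
# called with sens = out, the expected input gradient is 2 * 10 = 20; with the
# doubling hook registered on `mul`, it becomes 40.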
def pynative_hook_mul2_change_input_grad():
    input1_np = np.array([2.0, 3.0, 4.0]).astype(np.float32)
    input2_np = np.array([2.0, 3.0, 4.0]).astype(np.float32)

    ms_net = MsMultiInputNet()
    ms_net.set_grad()
    ms_net.mul2.register_backward_hook(ms_net.ms_change_grad_double_hook)
    input1_ms = Tensor(input1_np)
    input2_ms = Tensor(input2_np)
    out_ms = ms_net(input1_ms, input2_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    input_ms_grad = grad_net(input1_ms, input2_ms, out_ms)

    # input grad
    input1_torch_grad = np.array([384, 2916, 12288])
    input2_torch_grad = np.array([128, 972, 4096])
    allclose_nparray(input1_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
    allclose_nparray(input2_torch_grad, input_ms_grad[1].asnumpy(), 0.001, 0.001)

def pynative_hook_outermost_cell_change_grad():
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsNetWithCellinCell()
    ms_net.set_grad()
    ms_net.register_backward_hook(ms_net.ms_change_grad_double_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    input_ms_grad = grad_net(input_ms, out_ms)

    # input grad
    out_torch = np.array([[20, 20], [20, 20]])
    input_torch_grad = np.array([[160, 160], [160, 160]])
    allclose_nparray(out_torch, out_ms.asnumpy(), 0.001, 0.001)
    allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)

def pynative_hook_outermost_cell_record_grad():
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsSingleOpNetWithBprop()
    ms_net.set_grad()
    ms_net.bprop_debug = True
    ms_net.register_backward_hook(ms_net.ms_record_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    input_ms_grad = grad_net(input_ms, out_ms)

    # the cell defines its own bprop, so the registered hook must not fire
    assert not ms_net.grad_output_list and not ms_net.grad_input_list

    # input grad
    out_torch = np.array([[1, 1], [1, 1]])
    input_torch_grad = np.array([[5, 5], [5, 5]])
    allclose_nparray(out_torch, out_ms.asnumpy(), 0.001, 0.001)
    allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)

def pynative_hook_bprop_outermost_cell_record_grad():
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsNetHasBpropInChild()
    ms_net.set_grad()
    ms_net.bprop_net.bprop_debug = True
    ms_net.register_backward_hook(ms_net.ms_record_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    input_ms_grad = grad_net(input_ms, out_ms)

    # the hook on the outermost cell still fires and records matching pairs
    assert ms_net.grad_output_list
    assert len(ms_net.grad_output_list) == len(ms_net.grad_input_list)

    # input grad
    out_torch = np.array([[5, 5], [5, 5]])
    input_torch_grad = np.array([[25, 25], [25, 25]])
    allclose_nparray(out_torch, out_ms.asnumpy(), 0.001, 0.001)
    allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
    # hook record grad
    torch_net_grad_output = np.array([[5, 5], [5, 5]])
    torch_net_grad_input = np.array([[25, 25], [25, 25]])
    allclose_nparray(torch_net_grad_output, ms_net.grad_input_list[0].asnumpy(), 0.001, 0.001)
    allclose_nparray(torch_net_grad_input, ms_net.grad_output_list[0].asnumpy(), 0.001, 0.001)
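
# Hand derivation mirroring the expected arrays in
# pynative_hook_mul2_change_input_grad: output = (x*y) * (2x) = 2*x^2*y.
# The doubling hook on mul2 doubles the gradient through the 2x branch, so
# d/dx = 2xy + 4xy = 6xy and d/dy = 2x^2. Scaled elementwise by sens = output,
# x = y = [2, 3, 4] gives [384, 2916, 12288] and [128, 972, 4096].
#
# The next case checks the converse situation: when a cell defines its own
# bprop, hooks registered on its sub-cells are expected to stay silent.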
def pynative_hook_child_cell_record_grad():
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsMultiOpNetWithBprop()
    ms_net.set_grad()
    ms_net.bprop_debug = True
    ms_net.relu.register_backward_hook(ms_net.ms_record_hook)
    ms_net.mul.register_backward_hook(ms_net.ms_record_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    grad_net(input_ms, out_ms)

    # hooks on sub-cells of a cell with a custom bprop must not fire
    assert not ms_net.grad_output_list and not ms_net.grad_input_list
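
# Each scenario below is exposed twice: a level1 Ascend entry point and a
# level0 GPU entry point. The wrappers only select the PyNative context for
# the target device and delegate to the shared helper above.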

@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_diff_hook_ascend():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_diff_hook()

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_diff_hook_gpu():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_diff_hook()

@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_not_change_grad_ascend():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_outermost_cell_not_change_grad()

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_not_change_grad_gpu():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_outermost_cell_not_change_grad()

@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_all_cell_record_grad_ascend():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_all_cell_record_grad()

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_all_cell_record_grad_gpu():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_all_cell_record_grad()

@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_mul_change_input_grad_ascend():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_mul_change_input_grad()

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_mul_change_input_grad_gpu():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_mul_change_input_grad()

@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_mul2_change_input_grad_ascend():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_mul2_change_input_grad()

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_mul2_change_input_grad_gpu():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_mul2_change_input_grad()

@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_change_grad_ascend():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_outermost_cell_change_grad()

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_change_grad_gpu():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_outermost_cell_change_grad()

@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_record_grad_ascend():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_outermost_cell_record_grad()

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_record_grad_gpu():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_outermost_cell_record_grad()

@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_bprop_outermost_cell_record_grad_ascend():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_bprop_outermost_cell_record_grad()

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_bprop_outermost_cell_record_grad_gpu():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_bprop_outermost_cell_record_grad()

@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_child_cell_record_grad_ascend():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_child_cell_record_grad()

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_child_cell_record_grad_gpu():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_child_cell_record_grad()