# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_pynative_hook_grad """
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.ops.composite import GradOperation
from mindspore import ops as OP
from tests.st.pynative.utils import GradOfAllInputs


class MetaFactory:
    """Base holder for device information read from the MindSpore context."""

    def __init__(self):
        self.device_target = context.get_context('device_target')
        self.rank_size = None
        self.device_id = None
        self.global_rank_id = None


class HookBase(MetaFactory):
    """Mixin providing backward hooks that record or rescale gradients."""

    def __init__(self):
        # super() resolves to MetaFactory both standalone and when mixed in
        # after nn.Cell, so a second explicit MetaFactory.__init__ call is
        # not needed.
        super().__init__()
        self.grad_input_list = []
        self.grad_output_list = []

    def ms_record_hook(self, cell_id, grad_input, grad_output):
        """Append every element of grad_input / grad_output to the record lists."""
        self.grad_input_list.extend(grad_input)
        self.grad_output_list.extend(grad_output)

    def ms_change_grad_double_hook(self, cell_id, grad_input, grad_output):
        """Return a 1-tuple holding grad_output[0] multiplied by 2.0."""
        y = Tensor(np.array([2.0]).astype(np.float32))
        mul = P.Mul()
        grad = grad_output[0]
        output = mul(grad, y)
        return (output,)


class FinalNet(nn.Cell, HookBase):
    """Applies conv (flag truthy) or relu (flag falsy), followed by relu."""

    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.conv = nn.Conv2d(1, 3, 3)
        self.relu = nn.ReLU()

    def construct(self, x, flag):
        if flag:
            x = self.conv(x)
        else:
            x = self.relu(x)
        return self.relu(x)


class MsMul4(nn.Cell):
    """Multiplies its input by the constant 2."""

    def construct(self, input_mul):
        out = input_mul * 2
        return out


class MsMul(nn.Cell):
    """Element-wise product of two inputs via P.Mul."""

    def __init__(self):
        super().__init__()
        self.mul = P.Mul()

    def construct(self, x, y):
        x = self.mul(x, y)
        return x


class MsAdd4(nn.Cell):
    """Adds the constant 4 to its input."""

    def construct(self, input_add):
        out = input_add + 4
        return out


class MsOneInputNet(nn.Cell, HookBase):
    """Computes relu((x + 4) * 2) through three child cells."""

    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.add = MsAdd4()
        self.mul = MsMul4()
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.add(x)
        x = self.mul(x)
        out = self.relu(x)
        return out


class MsMultiInputNet(nn.Cell, HookBase):
    """Computes (x * y) * (x * 2); the mul1 cell is used twice."""

    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.mul1 = MsMul()
        self.mul2 = MsMul4()

    def construct(self, x, y):
        a = self.mul1(x, y)
        b = self.mul2(x)
        output = self.mul1(a, b)
        return output


class MsNetWithParameter(nn.Cell, HookBase):
    """Two stacked 1x1 Conv2d layers with all-ones weights and biases."""

    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.conv1 = nn.Conv2d(2, 4, kernel_size=(1, 1), has_bias=True,
                               weight_init=Tensor(np.ones([4, 2, 1, 1]).astype(np.float32)),
                               bias_init=Tensor(np.ones([4]).astype(np.float32)))
        self.conv2 = nn.Conv2d(4, 8, kernel_size=(1, 1), has_bias=True,
                               weight_init=Tensor(np.ones([8, 4, 1, 1]).astype(np.float32)),
                               bias_init=Tensor(np.ones([8]).astype(np.float32)))

    def construct(self, x):
        x = self.conv1(x)
        output = self.conv2(x)
        return output


class MsNetWithCellinCell(nn.Cell, HookBase):
    """Wraps MsOneInputNet and multiplies its output by 2."""

    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.net1 = MsOneInputNet()
        self.mul = MsMul4()

    def construct(self, x):
        x = self.net1(x)
        output = self.mul(x)
        return output


class MsSingleOpNetWithBprop(nn.Cell, HookBase):
    """Single ReLU cell whose custom bprop returns x * 5."""

    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.op = nn.ReLU()

    def construct(self, x):
        return self.op(x)

    def bprop(self, x, out, dout):
        y = Tensor(np.array([5.0]).astype(np.float32))
        mul = P.Mul()
        return mul(x, y)


class MsNetHasBpropInChild(nn.Cell, HookBase):
    """Adds 4, then runs a child cell that defines its own bprop."""

    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.add = MsAdd4()
        self.bprop_net = MsSingleOpNetWithBprop()

    def construct(self, x):
        x = self.add(x)
        return self.bprop_net(x)


class MsMultiOpNetWithBprop(nn.Cell, HookBase):
    """Computes relu(x * 2); its custom bprop returns x * 5."""

    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.mul = MsMul4()
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.mul(x)
        return self.relu(x)

    def bprop(self, x, out, dout):
        y = Tensor(np.array([5.0]).astype(np.float32))
        mul = P.Mul()
        return mul(x, y)


def _count_unequal_element(data_expected, data_me, rtol, atol):
    """Assert that the fraction of out-of-tolerance elements is below rtol.

    An element is out of tolerance when
    |data_expected - data_me| > atol + |data_me| * rtol.
    """
    assert data_expected.shape == data_me.shape
    total_count = len(data_expected.flatten())
    error = np.abs(data_expected - data_me)
    greater = np.greater(error, atol + np.abs(data_me) * rtol)
    loss_count = np.count_nonzero(greater)
    assert (loss_count / total_count) < rtol, (
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
            data_expected[greater], data_me[greater], error[greater]))


def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
    """Assert that two numpy arrays agree within the given tolerances.

    When data_expected contains NaN, np.allclose must hold outright.
    Otherwise a failed np.allclose falls back to _count_unequal_element,
    which tolerates a small fraction of mismatching elements.
    """
    if np.any(np.isnan(data_expected)):
        assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
    elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
        _count_unequal_element(data_expected, data_me, rtol, atol)


def pynative_hook_diff_hook():
    """Run FinalNet with a record hook on conv and a doubling hook on relu."""
    input_np = np.ones([1, 1, 224, 224]).astype(np.float32)
    ms_net = FinalNet()
    ms_net.set_grad()
    ms_net.conv.register_backward_hook(ms_net.ms_record_hook)
    ms_net.relu.register_backward_hook(ms_net.ms_change_grad_double_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms, Tensor(1))
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    grad_net(input_ms, Tensor(1), out_ms)


def pynative_hook_outermost_cell_not_change_grad():
    """Record hook on the outermost cell; check input grad and recorded grads."""
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsOneInputNet()
    ms_net.set_grad()
    ms_net.register_backward_hook(ms_net.ms_record_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    input_ms_grad = grad_net(input_ms, out_ms)

    # input grad
    input_torch_grad = np.array([[20, 20], [20, 20]])
    allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
    # hook record grad
    torch_net_grad_output = np.array([[10, 10], [10, 10]])
    torch_net_grad_input = np.array([[20, 20], [20, 20]])
    allclose_nparray(torch_net_grad_output, ms_net.grad_input_list[0].asnumpy(), 0.001, 0.001)
    allclose_nparray(torch_net_grad_input, ms_net.grad_output_list[0].asnumpy(), 0.001, 0.001)


def pynative_hook_all_cell_record_grad():
    """Record hooks on mul, add and relu; check the three recorded grad pairs."""
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsOneInputNet()
    ms_net.set_grad()
    ms_net.mul.register_backward_hook(ms_net.ms_record_hook)
    ms_net.add.register_backward_hook(ms_net.ms_record_hook)
    ms_net.relu.register_backward_hook(ms_net.ms_record_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    grad_net(input_ms, out_ms)

    torch_net_grad_input0 = np.array([[10, 10], [10, 10]])
    torch_net_grad_output0 = np.array([[10, 10], [10, 10]])
    torch_net_grad_input1 = np.array([[20, 20], [20, 20]])
    torch_net_grad_output1 = np.array([[10, 10], [10, 10]])
    allclose_nparray(torch_net_grad_input0, ms_net.grad_output_list[0].asnumpy(), 0.001, 0.001)
    allclose_nparray(torch_net_grad_output0, ms_net.grad_input_list[0].asnumpy(), 0.001, 0.001)
    allclose_nparray(torch_net_grad_input1, ms_net.grad_output_list[1].asnumpy(), 0.001, 0.001)
    allclose_nparray(torch_net_grad_output1, ms_net.grad_input_list[1].asnumpy(), 0.001, 0.001)

    torch_net_grad_input3 = np.array([[20, 20], [20, 20]])
    torch_net_grad_output2 = np.array([[20, 20], [20, 20]])
    allclose_nparray(torch_net_grad_input3, ms_net.grad_output_list[2].asnumpy(), 0.001, 0.001)
    allclose_nparray(torch_net_grad_output2, ms_net.grad_input_list[2].asnumpy(), 0.001, 0.001)


def pynative_hook_mul_change_input_grad():
    """Doubling hook on the mul child of MsOneInputNet; check the input grad."""
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsOneInputNet()
    ms_net.set_grad()
    ms_net.mul.register_backward_hook(ms_net.ms_change_grad_double_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    input_ms_grad = grad_net(input_ms, out_ms)

    # input grad
    input_torch_grad = np.array([[40, 40], [40, 40]])
    allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)


def pynative_hook_mul2_change_input_grad():
    """Doubling hook on mul2 of MsMultiInputNet; check both input grads."""
    input1_np = np.array([2.0, 3.0, 4.0]).astype(np.float32)
    input2_np = np.array([2.0, 3.0, 4.0]).astype(np.float32)

    ms_net = MsMultiInputNet()
    ms_net.set_grad()
    ms_net.mul2.register_backward_hook(ms_net.ms_change_grad_double_hook)
    input1_ms = Tensor(input1_np)
    input2_ms = Tensor(input2_np)
    out_ms = ms_net(input1_ms, input2_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    input_ms_grad = grad_net(input1_ms, input2_ms, out_ms)

    # input grad
    input1_torch_grad = np.array([384, 2916, 12288])
    input2_torch_grad = np.array([128, 972, 4096])
    allclose_nparray(input1_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
    allclose_nparray(input2_torch_grad, input_ms_grad[1].asnumpy(), 0.001, 0.001)


def pynative_hook_outermost_cell_change_grad():
    """Doubling hook on the outermost MsNetWithCellinCell; check output and grad."""
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsNetWithCellinCell()
    ms_net.set_grad()
    ms_net.register_backward_hook(ms_net.ms_change_grad_double_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    input_ms_grad = grad_net(input_ms, out_ms)

    # input grad
    out_torch = np.array([[20, 20], [20, 20]])
    input_torch_grad = np.array([[160, 160], [160, 160]])
    allclose_nparray(out_torch, out_ms.asnumpy(), 0.001, 0.001)
    allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)


def pynative_hook_outermost_cell_record_grad():
    """Record hook on a cell with custom bprop; expects nothing to be recorded."""
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsSingleOpNetWithBprop()
    ms_net.set_grad()
    ms_net.bprop_debug = True
    ms_net.register_backward_hook(ms_net.ms_record_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    input_ms_grad = grad_net(input_ms, out_ms)

    # The hook must not have recorded anything for this cell.
    assert not (ms_net.grad_output_list or ms_net.grad_input_list), \
        "record hook unexpectedly captured gradients"

    # input grad
    out_torch = np.array([[1, 1], [1, 1]])
    input_torch_grad = np.array([[5, 5], [5, 5]])
    allclose_nparray(out_torch, out_ms.asnumpy(), 0.001, 0.001)
    allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)


def pynative_hook_bprop_outermost_cell_record_grad():
    """Record hook on an outer cell whose child defines bprop; grads are recorded."""
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsNetHasBpropInChild()
    ms_net.set_grad()
    ms_net.bprop_net.bprop_debug = True
    ms_net.register_backward_hook(ms_net.ms_record_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    input_ms_grad = grad_net(input_ms, out_ms)

    # The hook must have recorded a non-empty, equally sized pair of lists.
    assert ms_net.grad_output_list, "record hook captured no gradients"
    assert len(ms_net.grad_output_list) == len(ms_net.grad_input_list), \
        "recorded grad_input and grad_output counts differ"

    # input grad
    out_torch = np.array([[5, 5], [5, 5]])
    input_torch_grad = np.array([[25, 25], [25, 25]])
    allclose_nparray(out_torch, out_ms.asnumpy(), 0.001, 0.001)
    allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
    # hook record grad
    torch_net_grad_output = np.array([[5, 5], [5, 5]])
    torch_net_grad_input = np.array([[25, 25], [25, 25]])
    allclose_nparray(torch_net_grad_output, ms_net.grad_input_list[0].asnumpy(), 0.001, 0.001)
    allclose_nparray(torch_net_grad_input, ms_net.grad_output_list[0].asnumpy(), 0.001, 0.001)


def pynative_hook_child_cell_record_grad():
    """Record hooks on children of a cell with custom bprop; expects no records."""
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsMultiOpNetWithBprop()
    ms_net.set_grad()
    ms_net.bprop_debug = True
    ms_net.relu.register_backward_hook(ms_net.ms_record_hook)
    ms_net.mul.register_backward_hook(ms_net.ms_record_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    grad_net(input_ms, out_ms)

    # The child hooks must not have recorded anything.
    assert not (ms_net.grad_output_list or ms_net.grad_input_list), \
        "record hook unexpectedly captured gradients"


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_diff_hook_ascend():
    """Run pynative_hook_diff_hook in PyNative mode on Ascend."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_diff_hook()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_diff_hook_gpu():
    """Run pynative_hook_diff_hook in PyNative mode on GPU."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_diff_hook()


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_not_change_grad_ascend():
    """Run pynative_hook_outermost_cell_not_change_grad in PyNative mode on Ascend."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_outermost_cell_not_change_grad()


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_not_change_grad_gpu():
    """Run pynative_hook_outermost_cell_not_change_grad in PyNative mode on GPU."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_outermost_cell_not_change_grad()


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_all_cell_record_grad_ascend():
    """Run pynative_hook_all_cell_record_grad in PyNative mode on Ascend."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_all_cell_record_grad()


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_all_cell_record_grad_gpu():
    """Run pynative_hook_all_cell_record_grad in PyNative mode on GPU."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_all_cell_record_grad()


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_mul_change_input_grad_ascend():
    """Run pynative_hook_mul_change_input_grad in PyNative mode on Ascend."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_mul_change_input_grad()


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_mul_change_input_grad_gpu():
    """Run pynative_hook_mul_change_input_grad in PyNative mode on GPU."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_mul_change_input_grad()


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_mul2_change_input_grad_ascend():
    """Run pynative_hook_mul2_change_input_grad in PyNative mode on Ascend."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_mul2_change_input_grad()


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_mul2_change_input_grad_gpu():
    """Run pynative_hook_mul2_change_input_grad in PyNative mode on GPU."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_mul2_change_input_grad()


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_change_grad_ascend():
    """Run pynative_hook_outermost_cell_change_grad in PyNative mode on Ascend."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_outermost_cell_change_grad()


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_change_grad_gpu():
    """Run pynative_hook_outermost_cell_change_grad in PyNative mode on GPU."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_outermost_cell_change_grad()


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_record_grad_ascend():
    """Run pynative_hook_outermost_cell_record_grad in PyNative mode on Ascend."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_outermost_cell_record_grad()


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_record_grad_gpu():
    """Run pynative_hook_outermost_cell_record_grad in PyNative mode on GPU."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_outermost_cell_record_grad()


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_bprop_outermost_cell_record_grad_ascend():
    """Run pynative_hook_bprop_outermost_cell_record_grad in PyNative mode on Ascend."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_bprop_outermost_cell_record_grad()


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_bprop_outermost_cell_record_grad_gpu():
    """Run pynative_hook_bprop_outermost_cell_record_grad in PyNative mode on GPU."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_bprop_outermost_cell_record_grad()


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_child_cell_record_grad_ascend():
    """Run pynative_hook_child_cell_record_grad in PyNative mode on Ascend."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_child_cell_record_grad()


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_child_cell_record_grad_gpu():
    """Run pynative_hook_child_cell_record_grad in PyNative mode on GPU."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_child_cell_record_grad()


def backward_hook(cell_id, grad_input, grad_output):
    """
    Print the received gradients and return two fixed float32 tensors,
    [2, 3, 4, 5] and [5, 6, 7, 8], in their place.
    """
    print("input: ", grad_input)
    print("outpt: ", grad_output)
    return Tensor(np.array([2, 3, 4, 5])).astype(np.float32), Tensor(np.array([5, 6, 7, 8]).astype(np.float32))


class HookNet(nn.Cell):
    """Computes relu(matmul(x, y)) + y with backward_hook registered on the matmul cell."""

    def __init__(self):
        super().__init__()
        self.mul = nn.MatMul()
        self.relu = nn.ReLU()
        # Keep the value returned by registration on the instance.
        self.handle = self.mul.register_backward_hook(backward_hook)

    def construct(self, x, y):
        x = self.mul(x, y)
        x = self.relu(x)
        x = x + y
        return x


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_backward_hook_normal():
    """
    Feature: Test hook grad feature
    Description: test backward hook normal
    Expectation: Success
    """

    context.set_context(mode=context.PYNATIVE_MODE)
    input_x = Tensor(np.array([1, 2, 3, 4]).astype(np.float32))
    input_y = Tensor(np.array([5, 6, 7, 8]).astype(np.float32))
    net = HookNet()
    # Repeat the gradient computation on the same net instance.
    for _ in range(5):
        grad_op = GradOperation(get_all=True, get_by_list=False, sens_param=False)
        grad = grad_op(net)(input_x, input_y)
    # NOTE(review): the source lost its indentation; these checks are assumed
    # to run once after the loop (on the last result) - confirm against upstream.
    assert np.allclose(grad[0].asnumpy(), Tensor(np.array([2, 3, 4, 5])).astype(np.float32).asnumpy(), 0.001, 0.001)
    assert np.allclose(grad[1].asnumpy(), Tensor(np.array([6, 7, 8, 9])).astype(np.float32).asnumpy(), 0.001, 0.001)


class NetWithSaveGrad(nn.Cell):
    """Dense(3, 2) layer whose output passes through a fresh HookBackward op on every call."""

    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(3, 2)

    def construct(self, x):
        x = self.dense(x)
        # A new hook closure (with its own counter) is built on each forward pass.
        hook = OP.HookBackward(hook_wrapper())
        x = hook(x)
        return x


def hook_wrapper():
    """Return a hook function that asserts it is invoked at most once."""
    cnt = 0

    def hook_fn(grad):
        nonlocal cnt
        # A second invocation of the same closure fails this assertion.
        assert cnt == 0
        cnt = cnt + 1
    return hook_fn


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_hookbackward_should_two_zero():
    """
    Feature: Test hook backward feature
    Description: test hook need reconstruct grad graph
    Expectation: Success
    """
    context.set_context(mode=context.PYNATIVE_MODE)
    data = np.array([0.2, 0.5, 0.2], dtype=np.float32)
    label = np.array([1, 0], dtype=np.float32)

    net = NetWithSaveGrad()
    loss_fn = nn.CrossEntropyLoss()

    def forward_fn(data, label):
        logits = OP.squeeze(net(data))
        loss = loss_fn(logits, label)
        return loss, logits

    grad_fn = OP.grad(forward_fn, grad_position=None, weights=net.trainable_params(), has_aux=True)
    # Run twice: each run builds a new hook closure inside construct.
    for _ in range(2):
        _, _ = grad_fn(OP.unsqueeze(Tensor(data), dim=0), Tensor(label))