# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15""" test_pynative_hook_grad """
16import numpy as np
17import pytest
18import mindspore.nn as nn
19import mindspore.ops.operations as P
20from mindspore import context
21from mindspore.common.tensor import Tensor
22from mindspore.ops.composite import GradOperation
23from mindspore import ops as OP
24from tests.st.pynative.utils import GradOfAllInputs


class MetaFactory:
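    """Record device metadata from the current MindSpore context."""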
    def __init__(self):
        self.device_target = context.get_context('device_target')
        self.rank_size = None
        self.device_id = None
        self.global_rank_id = None


class HookBase(MetaFactory):
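    """Mixin that supplies backward hooks which record or rescale gradients."""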
    def __init__(self):
        super().__init__()
        self.grad_input_list = []
        self.grad_output_list = []

    def ms_record_hook(self, cell_id, grad_input, grad_output):
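        """Append every incoming grad_input/grad_output tensor for later checks."""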
        for grad in grad_input:
            self.grad_input_list.append(grad)
        for grad in grad_output:
            self.grad_output_list.append(grad)

    def ms_change_grad_double_hook(self, cell_id, grad_input, grad_output):
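        """Replace the cell's gradient with 2 * grad_output[0]."""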
        y = Tensor(np.array([2.0]).astype(np.float32))
        mul = P.Mul()
        grad = grad_output[0]
        output = mul(grad, y)
        return (output,)


class FinalNet(nn.Cell, HookBase):
    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.conv = nn.Conv2d(1, 3, 3)
        self.relu = nn.ReLU()

    def construct(self, x, flag):
        if flag:
            x = self.conv(x)
        else:
            x = self.relu(x)
        return self.relu(x)


class MsMul4(nn.Cell):
    def construct(self, input_mul):
        out = input_mul * 2
        return out


class MsMul(nn.Cell):
    def __init__(self):
        super().__init__()
        self.mul = P.Mul()

    def construct(self, x, y):
        x = self.mul(x, y)
        return x


class MsAdd4(nn.Cell):
    def construct(self, input_add):
        out = input_add + 4
        return out


class MsOneInputNet(nn.Cell, HookBase):
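    """Computes relu(2 * (x + 4)); for positive x, d(out)/dx = 2 * sens."""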
    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.add = MsAdd4()
        self.mul = MsMul4()
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.add(x)
        x = self.mul(x)
        out = self.relu(x)
        return out


class MsMultiInputNet(nn.Cell, HookBase):
    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.mul1 = MsMul()
        self.mul2 = MsMul4()

    def construct(self, x, y):
        a = self.mul1(x, y)
        b = self.mul2(x)
        output = self.mul1(a, b)
        return output


class MsNetWithParameter(nn.Cell, HookBase):
    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.conv1 = nn.Conv2d(2, 4, kernel_size=(1, 1), has_bias=True,
                               weight_init=Tensor(np.ones([4, 2, 1, 1]).astype(np.float32)),
                               bias_init=Tensor(np.ones([4]).astype(np.float32)))
        self.conv2 = nn.Conv2d(4, 8, kernel_size=(1, 1), has_bias=True,
                               weight_init=Tensor(np.ones([8, 4, 1, 1]).astype(np.float32)),
                               bias_init=Tensor(np.ones([8]).astype(np.float32)))

    def construct(self, x):
        x = self.conv1(x)
        output = self.conv2(x)
        return output


class MsNetWithCellinCell(nn.Cell, HookBase):
    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.net1 = MsOneInputNet()
        self.mul = MsMul4()

    def construct(self, x):
        x = self.net1(x)
        output = self.mul(x)
        return output


class MsSingleOpNetWithBprop(nn.Cell, HookBase):
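    """ReLU cell whose custom bprop returns 5 * x and ignores dout."""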
    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.op = nn.ReLU()

    def construct(self, x):
        return self.op(x)

    def bprop(self, x, out, dout):
        y = Tensor(np.array([5.0]).astype(np.float32))
        mul = P.Mul()
        return mul(x, y)


class MsNetHasBpropInChild(nn.Cell, HookBase):
    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.add = MsAdd4()
        self.bprop_net = MsSingleOpNetWithBprop()

    def construct(self, x):
        x = self.add(x)
        return self.bprop_net(x)


class MsMultiOpNetWithBprop(nn.Cell, HookBase):
    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.mul = MsMul4()
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.mul(x)
        return self.relu(x)

    def bprop(self, x, out, dout):
        y = Tensor(np.array([5.0]).astype(np.float32))
        mul = P.Mul()
        return mul(x, y)


def _count_unequal_element(data_expected, data_me, rtol, atol):
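    """Assert that the fraction of elements violating atol + rtol * |data_me| stays below rtol."""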
    assert data_expected.shape == data_me.shape
    total_count = len(data_expected.flatten())
    error = np.abs(data_expected - data_me)
    greater = np.greater(error, atol + np.abs(data_me) * rtol)
    loss_count = np.count_nonzero(greater)
    assert (loss_count / total_count) < rtol, \
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
            data_expected[greater], data_me[greater], error[greater])


def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
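    """np.allclose check that falls back to an element-wise mismatch report."""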
    if np.any(np.isnan(data_expected)):
        assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
    elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
        _count_unequal_element(data_expected, data_me, rtol, atol)


def pynative_hook_diff_hook():
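    """Run backward with a recording hook on conv and a grad-doubling hook on relu."""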
    input_np = np.ones([1, 1, 224, 224]).astype(np.float32)
    ms_net = FinalNet()
    ms_net.set_grad()
    ms_net.conv.register_backward_hook(ms_net.ms_record_hook)
    ms_net.relu.register_backward_hook(ms_net.ms_change_grad_double_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms, Tensor(1))
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    grad_net(input_ms, Tensor(1), out_ms)


def pynative_hook_outermost_cell_not_change_grad():
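    """A recording hook on the outermost cell must not change the input grad."""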
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsOneInputNet()
    ms_net.set_grad()
    ms_net.register_backward_hook(ms_net.ms_record_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    input_ms_grad = grad_net(input_ms, out_ms)

    # input grad: out = relu(2 * (x + 4)) = 10 and sens = out, so dx = 2 * 10 = 20
    input_torch_grad = np.array([[20, 20], [20, 20]])
    allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
    # hook record grad
    torch_net_grad_output = np.array([[10, 10], [10, 10]])
    torch_net_grad_input = np.array([[20, 20], [20, 20]])
    allclose_nparray(torch_net_grad_output, ms_net.grad_input_list[0].asnumpy(), 0.001, 0.001)
    allclose_nparray(torch_net_grad_input, ms_net.grad_output_list[0].asnumpy(), 0.001, 0.001)


def pynative_hook_all_cell_record_grad():
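    """Hooks on every sub-cell record grads in backward (reverse execution) order."""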
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsOneInputNet()
    ms_net.set_grad()
    ms_net.mul.register_backward_hook(ms_net.ms_record_hook)
    ms_net.add.register_backward_hook(ms_net.ms_record_hook)
    ms_net.relu.register_backward_hook(ms_net.ms_record_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    grad_net(input_ms, out_ms)

    torch_net_grad_input0 = np.array([[10, 10], [10, 10]])
    torch_net_grad_output0 = np.array([[10, 10], [10, 10]])
    torch_net_grad_input1 = np.array([[20, 20], [20, 20]])
    torch_net_grad_output1 = np.array([[10, 10], [10, 10]])
    allclose_nparray(torch_net_grad_input0, ms_net.grad_output_list[0].asnumpy(), 0.001, 0.001)
    allclose_nparray(torch_net_grad_output0, ms_net.grad_input_list[0].asnumpy(), 0.001, 0.001)
    allclose_nparray(torch_net_grad_input1, ms_net.grad_output_list[1].asnumpy(), 0.001, 0.001)
    allclose_nparray(torch_net_grad_output1, ms_net.grad_input_list[1].asnumpy(), 0.001, 0.001)

    torch_net_grad_input2 = np.array([[20, 20], [20, 20]])
    torch_net_grad_output2 = np.array([[20, 20], [20, 20]])
    allclose_nparray(torch_net_grad_input2, ms_net.grad_output_list[2].asnumpy(), 0.001, 0.001)
    allclose_nparray(torch_net_grad_output2, ms_net.grad_input_list[2].asnumpy(), 0.001, 0.001)


def pynative_hook_mul_change_input_grad():
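    """A grad-doubling hook on the inner mul cell doubles the input grad from 20 to 40."""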
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsOneInputNet()
    ms_net.set_grad()
    ms_net.mul.register_backward_hook(ms_net.ms_change_grad_double_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    input_ms_grad = grad_net(input_ms, out_ms)

    # input grad: the hook doubles the mul cell's grad, so 20 becomes 40
    input_torch_grad = np.array([[40, 40], [40, 40]])
    allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)


def pynative_hook_mul2_change_input_grad():
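    """A grad-doubling hook on mul2 rescales only the grads flowing through that branch."""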
    input1_np = np.array([2.0, 3.0, 4.0]).astype(np.float32)
    input2_np = np.array([2.0, 3.0, 4.0]).astype(np.float32)

    ms_net = MsMultiInputNet()
    ms_net.set_grad()
    ms_net.mul2.register_backward_hook(ms_net.ms_change_grad_double_hook)
    input1_ms = Tensor(input1_np)
    input2_ms = Tensor(input2_np)
    out_ms = ms_net(input1_ms, input2_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    input_ms_grad = grad_net(input1_ms, input2_ms, out_ms)

    # input grad
    input1_torch_grad = np.array([384, 2916, 12288])
    input2_torch_grad = np.array([128, 972, 4096])
    allclose_nparray(input1_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
    allclose_nparray(input2_torch_grad, input_ms_grad[1].asnumpy(), 0.001, 0.001)


def pynative_hook_outermost_cell_change_grad():
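    """A grad-doubling hook on the outermost cell doubles the input grad from 80 to 160."""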
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsNetWithCellinCell()
    ms_net.set_grad()
    ms_net.register_backward_hook(ms_net.ms_change_grad_double_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    input_ms_grad = grad_net(input_ms, out_ms)

    # forward output and input grad
    out_torch = np.array([[20, 20], [20, 20]])
    input_torch_grad = np.array([[160, 160], [160, 160]])
    allclose_nparray(out_torch, out_ms.asnumpy(), 0.001, 0.001)
    allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)


def pynative_hook_outermost_cell_record_grad():
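    """When a cell runs its custom bprop (bprop_debug), its backward hook must not fire."""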
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsSingleOpNetWithBprop()
    ms_net.set_grad()
    ms_net.bprop_debug = True
    ms_net.register_backward_hook(ms_net.ms_record_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    input_ms_grad = grad_net(input_ms, out_ms)

    # the cell's own bprop takes over, so the hook must record nothing
    assert not ms_net.grad_output_list
    assert not ms_net.grad_input_list

    # forward output and input grad produced by the custom bprop (5 * x)
    out_torch = np.array([[1, 1], [1, 1]])
    input_torch_grad = np.array([[5, 5], [5, 5]])
    allclose_nparray(out_torch, out_ms.asnumpy(), 0.001, 0.001)
    allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)


def pynative_hook_bprop_outermost_cell_record_grad():
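    """A hook on the outer cell still records grads when only a child defines bprop."""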
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsNetHasBpropInChild()
    ms_net.set_grad()
    ms_net.bprop_net.bprop_debug = True
    ms_net.register_backward_hook(ms_net.ms_record_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    input_ms_grad = grad_net(input_ms, out_ms)

    # the outer hook must fire and record matching, non-empty grad lists
    assert ms_net.grad_output_list
    assert len(ms_net.grad_output_list) == len(ms_net.grad_input_list)

    # input grad
    out_torch = np.array([[5, 5], [5, 5]])
    input_torch_grad = np.array([[25, 25], [25, 25]])
    allclose_nparray(out_torch, out_ms.asnumpy(), 0.001, 0.001)
    allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
    # hook record grad
    torch_net_grad_output = np.array([[5, 5], [5, 5]])
    torch_net_grad_input = np.array([[25, 25], [25, 25]])
    allclose_nparray(torch_net_grad_output, ms_net.grad_input_list[0].asnumpy(), 0.001, 0.001)
    allclose_nparray(torch_net_grad_input, ms_net.grad_output_list[0].asnumpy(), 0.001, 0.001)


def pynative_hook_child_cell_record_grad():
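    """Hooks on sub-cells covered by a cell-level bprop must record nothing."""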
    input_np = np.ones([2, 2]).astype(np.float32)

    ms_net = MsMultiOpNetWithBprop()
    ms_net.set_grad()
    ms_net.bprop_debug = True
    ms_net.relu.register_backward_hook(ms_net.ms_record_hook)
    ms_net.mul.register_backward_hook(ms_net.ms_record_hook)
    input_ms = Tensor(input_np)
    out_ms = ms_net(input_ms)
    grad_net = GradOfAllInputs(ms_net)
    grad_net.set_train()
    grad_net(input_ms, out_ms)

    # the cell-level bprop replaces the children's backward, so their hooks record nothing
    assert not ms_net.grad_output_list
    assert not ms_net.grad_input_list


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_diff_hook_ascend():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_diff_hook()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_diff_hook_gpu():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_diff_hook()


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_not_change_grad_ascend():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_outermost_cell_not_change_grad()


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_not_change_grad_gpu():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_outermost_cell_not_change_grad()


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_all_cell_record_grad_ascend():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_all_cell_record_grad()


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_all_cell_record_grad_gpu():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_all_cell_record_grad()


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_mul_change_input_grad_ascend():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_mul_change_input_grad()


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_mul_change_input_grad_gpu():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_mul_change_input_grad()


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_mul2_change_input_grad_ascend():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_mul2_change_input_grad()


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_mul2_change_input_grad_gpu():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_mul2_change_input_grad()


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_change_grad_ascend():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_outermost_cell_change_grad()


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_change_grad_gpu():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_outermost_cell_change_grad()


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_record_grad_ascend():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_outermost_cell_record_grad()


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_record_grad_gpu():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_outermost_cell_record_grad()


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_bprop_outermost_cell_record_grad_ascend():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_bprop_outermost_cell_record_grad()


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_bprop_outermost_cell_record_grad_gpu():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_bprop_outermost_cell_record_grad()


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_child_cell_record_grad_ascend():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_child_cell_record_grad()


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_child_cell_record_grad_gpu():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_child_cell_record_grad()


def backward_hook(cell_id, grad_input, grad_output):
560    """
561    print backward hook
562    """
563    print("input: ", grad_input)
564    print("outpt: ", grad_output)
565    return Tensor(np.array([2, 3, 4, 5])).astype(np.float32), Tensor(np.array([5, 6, 7, 8]).astype(np.float32))


class HookNet(nn.Cell):
    def __init__(self):
        super(HookNet, self).__init__()
        self.mul = nn.MatMul()
        self.relu = nn.ReLU()
        self.handle = self.mul.register_backward_hook(backward_hook)

    def construct(self, x, y):
        x = self.mul(x, y)
        x = self.relu(x)
        x = x + y
        return x


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_backward_hook_normal():
    """
    Feature: Test hook grad feature
    Description: test backward hook normal
    Expectation: Success
    """
    context.set_context(mode=context.PYNATIVE_MODE)
    input_x = Tensor(np.array([1, 2, 3, 4]).astype(np.float32))
    input_y = Tensor(np.array([5, 6, 7, 8]).astype(np.float32))
    net = HookNet()
    for _ in range(5):
        grad_op = GradOperation(get_all=True, get_by_list=False, sens_param=False)
        grad = grad_op(net)(input_x, input_y)
    # grad[0] is exactly the hook's replacement; grad[1] adds 1 from the `x + y` branch
    assert np.allclose(grad[0].asnumpy(), np.array([2, 3, 4, 5]).astype(np.float32), 0.001, 0.001)
    assert np.allclose(grad[1].asnumpy(), np.array([6, 7, 8, 9]).astype(np.float32), 0.001, 0.001)


class NetWithSaveGrad(nn.Cell):
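    """Dense layer whose output grad is intercepted by an ops.HookBackward op."""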
    def __init__(self):
        super(NetWithSaveGrad, self).__init__()
        self.dense = nn.Dense(3, 2)

    def construct(self, x):
        x = self.dense(x)
        hook = OP.HookBackward(hook_wrapper())
        x = hook(x)
        return x


def hook_wrapper():
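    """Build a hook_fn whose closure asserts the hook fires at most once per call."""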
    cnt = 0

    def hook_fn(grad):
        nonlocal cnt
        assert cnt == 0
        cnt = cnt + 1
    return hook_fn


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_hookbackward_should_two_zero():
629    """
630    Feature: Test hook backward feature
631    Description: test hook need reconstruct grad graph
632    Expectation: Success
633    """
634    context.set_context(mode=context.PYNATIVE_MODE)
635    data = np.array([0.2, 0.5, 0.2], dtype=np.float32)
636    label = np.array([1, 0], dtype=np.float32)
637
638    net = NetWithSaveGrad()
639    loss_fn = nn.CrossEntropyLoss()
640
641    def forward_fn(data, label):
642        logits = OP.squeeze(net(data))
643        loss = loss_fn(logits, label)
644        return loss, logits
645
646    grad_fn = OP.grad(forward_fn, grad_position=None, weights=net.trainable_params(), has_aux=True)
647    for _ in range(2):
648        _, _ = grad_fn(OP.unsqueeze(Tensor(data), dim=0), Tensor(label))
649