• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test_pynative_hook_grad """
16import numpy as np
17import pytest
18import mindspore.nn as nn
19import mindspore.ops.operations as P
20from mindspore.nn import Cell
21from mindspore import context
22from mindspore.common.tensor import Tensor
23from mindspore.ops.composite import GradOperation
24from mindspore.common import ParameterTuple
25
class MetaFactory:
    """Records device-placement metadata shared by the test networks."""

    def __init__(self):
        # Only the device target is known up front; the distributed-run
        # fields stay None in these single-card tests.
        self.device_target = context.get_context('device_target')
        self.rank_size = self.device_id = self.global_rank_id = None
32
class HookBase(MetaFactory):
    """Mixin providing ready-made cell backward-hook callbacks.

    `ms_record_hook` stores every gradient it sees; `ms_change_grad_double_hook`
    replaces the propagated gradient with twice the first output gradient.
    """

    def __init__(self):
        # Fix: the original called both super().__init__() and
        # MetaFactory.__init__(self); super() already reaches MetaFactory in
        # the MRO, so the explicit second call double-initialized the base.
        super().__init__()
        self.grad_input_list = []   # gradients flowing into the hooked cell
        self.grad_output_list = []  # gradients flowing out of the hooked cell

    def ms_record_hook(self, cell_id, grad_input, grad_output):
        """Backward hook: record all input/output gradients of the cell."""
        self.grad_input_list.extend(grad_input)
        self.grad_output_list.extend(grad_output)

    def ms_change_grad_double_hook(self, cell_id, grad_input, grad_output):
        """Backward hook: return 2 * grad_output[0] as the new gradient."""
        factor = Tensor(np.array([2.0]).astype(np.float32))
        return P.Mul()(grad_output[0], factor)
52
class FinalNet(nn.Cell, HookBase):
    """Conv/ReLU network whose first stage is selected by a flag input."""

    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.conv = nn.Conv2d(1, 3, 3)
        self.relu = nn.ReLU()

    def construct(self, x, flag):
        # First stage: conv when flag is truthy, otherwise relu.
        x = self.conv(x) if flag else self.relu(x)
        return self.relu(x)
66
class _Grad(Cell):
    """Wrapper cell that returns gradients of `network` computed by `grad`.

    Args:
        grad: a GradOperation describing which gradients to take.
        network: the cell to differentiate.
        wrt_params: if True, also differentiate w.r.t. trainable parameters.
        real_inputs_count: when the grad op takes a sens parameter, the number
            of leading call arguments that are real network inputs; the
            remaining arguments are bundled into a tuple and passed as sens.
    """
    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
        super().__init__()
        self.network = network
        self.grad = grad
        # Whether the grad operation expects an extra sensitivity argument.
        self.sens_param = self.grad.sens_param
        self.wrt_params = wrt_params
        self.real_inputs_count = real_inputs_count
        if self.wrt_params:
            self.params = ParameterTuple(self.network.trainable_params())

    def construct(self, *inputs):
        # Four cases: with/without parameter gradients, and with/without an
        # explicit split of `inputs` into real inputs + sens values.
        if self.wrt_params:
            if self.real_inputs_count is None or self.sens_param is False:
                return self.grad(self.network, self.params)(*inputs)
            real_inputs = inputs[:self.real_inputs_count]
            sense_param_inputs = inputs[self.real_inputs_count:]
            return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
        if self.real_inputs_count is None or self.sens_param is False:
            return self.grad(self.network)(*inputs)
        real_inputs = inputs[:self.real_inputs_count]
        # The trailing arguments are passed as one tuple — the sens value.
        sense_param_inputs = inputs[self.real_inputs_count:]
        return self.grad(self.network)(*real_inputs, sense_param_inputs)
90
class GradOfAllInputs(_Grad):
    """Gradients of the network w.r.t. every input (sens supplied by caller)."""

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        grad_op = GradOperation(get_all=True, sens_param=sens_param)
        super().__init__(grad=grad_op, network=network,
                         real_inputs_count=real_inputs_count)
95
class MsMul4(nn.Cell):
    """Cell that multiplies its input by the constant 2."""

    def construct(self, input_mul):
        return input_mul * 2
100
class MsMul(nn.Cell):
    """Cell computing the element-wise product of two inputs via P.Mul."""

    def __init__(self):
        super().__init__()
        self.mul = P.Mul()

    def construct(self, x, y):
        return self.mul(x, y)
109
class MsAdd4(nn.Cell):
    """Cell that adds the constant 4 to its input."""

    def construct(self, input_add):
        return input_add + 4
114
class MsOneInputNet(nn.Cell, HookBase):
    """Single-input net computing relu((x + 4) * 2)."""

    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.add = MsAdd4()
        self.mul = MsMul4()
        self.relu = nn.ReLU()

    def construct(self, x):
        return self.relu(self.mul(self.add(x)))
128
class MsMultiInputNet(nn.Cell, HookBase):
    """Two-input net computing mul1(mul1(x, y), mul2(x)) = (x*y) * (2*x)."""

    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.mul1 = MsMul()
        self.mul2 = MsMul4()

    def construct(self, x, y):
        left = self.mul1(x, y)
        right = self.mul2(x)
        return self.mul1(left, right)
140
class MsNetWithParameter(nn.Cell, HookBase):
    """Two stacked 1x1 convolutions whose weights and biases are all ones."""

    @staticmethod
    def _ones(shape):
        # All-ones float32 initializer tensor for weights/biases.
        return Tensor(np.ones(shape).astype(np.float32))

    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.conv1 = nn.Conv2d(2, 4, kernel_size=(1, 1), has_bias=True,
                               weight_init=self._ones([4, 2, 1, 1]),
                               bias_init=self._ones([4]))
        self.conv2 = nn.Conv2d(4, 8, kernel_size=(1, 1), has_bias=True,
                               weight_init=self._ones([8, 4, 1, 1]),
                               bias_init=self._ones([8]))

    def construct(self, x):
        return self.conv2(self.conv1(x))
156
class MsNetWithCellinCell(nn.Cell, HookBase):
    """MsOneInputNet followed by an extra multiply-by-2 cell."""

    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.net1 = MsOneInputNet()
        self.mul = MsMul4()

    def construct(self, x):
        return self.mul(self.net1(x))
168
class MsSingleOpNetWithBprop(nn.Cell, HookBase):
    """Single ReLU cell with a custom bprop that returns 5 * x."""

    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.op = nn.ReLU()

    def construct(self, x):
        return self.op(x)

    def bprop(self, x, out, dout):
        # Custom gradient: ignore dout and scale the input by 5.
        scale = Tensor(np.array([5.0]).astype(np.float32))
        return P.Mul()(x, scale)
182
class MsNetHasBpropInChild(nn.Cell, HookBase):
    """Add-then-ReLU net whose child cell defines a custom bprop."""

    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.add = MsAdd4()
        self.bprop_net = MsSingleOpNetWithBprop()

    def construct(self, x):
        return self.bprop_net(self.add(x))
193
class MsMultiOpNetWithBprop(nn.Cell, HookBase):
    """Mul-then-ReLU net with a custom bprop that returns 5 * x."""

    def __init__(self):
        super().__init__()
        HookBase.__init__(self)
        self.mul = MsMul4()
        self.relu = nn.ReLU()

    def construct(self, x):
        return self.relu(self.mul(x))

    def bprop(self, x, out, dout):
        # Custom gradient: ignore dout and scale the input by 5.
        scale = Tensor(np.array([5.0]).astype(np.float32))
        return P.Mul()(x, scale)
209
210def _count_unequal_element(data_expected, data_me, rtol, atol):
211    assert data_expected.shape == data_me.shape
212    total_count = len(data_expected.flatten())
213    error = np.abs(data_expected - data_me)
214    greater = np.greater(error, atol + np.abs(data_me)*rtol)
215    loss_count = np.count_nonzero(greater)
216    assert (loss_count/total_count) < rtol,\
217        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".\
218        format(data_expected[greater], data_me[greater], error[greater])
219
def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
    """Assert element-wise closeness; report offending elements on failure."""
    if np.any(np.isnan(data_expected)):
        # NaNs present: rely directly on allclose's equal_nan handling.
        assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
        return
    if not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
        # Produce a diagnostic assertion listing the offending elements.
        _count_unequal_element(data_expected, data_me, rtol, atol)
227
def pynative_hook_diff_hook():
    """Register different hooks on different sub-cells and run backward."""
    net = FinalNet()
    net.set_grad()
    net.conv.register_backward_hook(net.ms_record_hook)
    net.relu.register_backward_hook(net.ms_change_grad_double_hook)
    inp = Tensor(np.ones([1, 1, 224, 224]).astype(np.float32))
    out = net(inp, Tensor(1))
    grad_net = GradOfAllInputs(net)
    grad_net.set_train()
    grad_net(inp, Tensor(1), out)
239
def pynative_hook_outermost_cell_not_change_grad():
    """A recording hook on the outermost cell must not alter the input grad."""
    net = MsOneInputNet()
    net.set_grad()
    net.register_backward_hook(net.ms_record_hook)
    inp = Tensor(np.ones([2, 2]).astype(np.float32))
    out = net(inp)
    grad_net = GradOfAllInputs(net)
    grad_net.set_train()
    grads = grad_net(inp, out)

    # Gradient w.r.t. the network input is unchanged by the hook.
    expect_input_grad = np.array([[20, 20], [20, 20]])
    allclose_nparray(expect_input_grad, grads[0].asnumpy(), 0.001, 0.001)
    # Gradients captured by the recording hook.
    expect_grad_output = np.array([[10, 10], [10, 10]])
    expect_grad_input = np.array([[20, 20], [20, 20]])
    allclose_nparray(expect_grad_output, net.grad_input_list[0].asnumpy(), 0.001, 0.001)
    allclose_nparray(expect_grad_input, net.grad_output_list[0].asnumpy(), 0.001, 0.001)
260
def pynative_hook_all_cell_record_grad():
    """Recording hooks on every sub-cell capture all intermediate grads."""
    net = MsOneInputNet()
    net.set_grad()
    net.mul.register_backward_hook(net.ms_record_hook)
    net.add.register_backward_hook(net.ms_record_hook)
    net.relu.register_backward_hook(net.ms_record_hook)
    inp = Tensor(np.ones([2, 2]).astype(np.float32))
    out = net(inp)
    grad_net = GradOfAllInputs(net)
    grad_net.set_train()
    grad_net(inp, out)

    # (expected grad_output, expected grad_input) per recorded hook call,
    # in the order the hooks fired during backward.
    expectations = [
        (np.array([[10, 10], [10, 10]]), np.array([[10, 10], [10, 10]])),
        (np.array([[20, 20], [20, 20]]), np.array([[10, 10], [10, 10]])),
        (np.array([[20, 20], [20, 20]]), np.array([[20, 20], [20, 20]])),
    ]
    for idx, (expect_out, expect_in) in enumerate(expectations):
        allclose_nparray(expect_out, net.grad_output_list[idx].asnumpy(), 0.001, 0.001)
        allclose_nparray(expect_in, net.grad_input_list[idx].asnumpy(), 0.001, 0.001)
288
def pynative_hook_mul_change_input_grad():
    """A grad-doubling hook on the inner mul cell doubles the input grad."""
    net = MsOneInputNet()
    net.set_grad()
    net.mul.register_backward_hook(net.ms_change_grad_double_hook)
    inp = Tensor(np.ones([2, 2]).astype(np.float32))
    out = net(inp)
    grad_net = GradOfAllInputs(net)
    grad_net.set_train()
    grads = grad_net(inp, out)

    # Input gradient with the doubling hook applied.
    expect_input_grad = np.array([[40, 40], [40, 40]])
    allclose_nparray(expect_input_grad, grads[0].asnumpy(), 0.001, 0.001)
304
def pynative_hook_mul2_change_input_grad():
    """A grad-doubling hook on mul2 changes both input gradients."""
    x_np = np.array([2.0, 3.0, 4.0]).astype(np.float32)
    y_np = np.array([2.0, 3.0, 4.0]).astype(np.float32)

    net = MsMultiInputNet()
    net.set_grad()
    net.mul2.register_backward_hook(net.ms_change_grad_double_hook)
    x, y = Tensor(x_np), Tensor(y_np)
    out = net(x, y)
    grad_net = GradOfAllInputs(net)
    grad_net.set_train()
    grads = grad_net(x, y, out)

    # Expected input gradients with the hook applied.
    allclose_nparray(np.array([384, 2916, 12288]), grads[0].asnumpy(), 0.001, 0.001)
    allclose_nparray(np.array([128, 972, 4096]), grads[1].asnumpy(), 0.001, 0.001)
324
def pynative_hook_outermost_cell_change_grad():
    """A grad-doubling hook on the outermost cell scales the input grad."""
    net = MsNetWithCellinCell()
    net.set_grad()
    net.register_backward_hook(net.ms_change_grad_double_hook)
    inp = Tensor(np.ones([2, 2]).astype(np.float32))
    out = net(inp)
    grad_net = GradOfAllInputs(net)
    grad_net.set_train()
    grads = grad_net(inp, out)

    # Forward output and hooked input gradient.
    expect_out = np.array([[20, 20], [20, 20]])
    expect_input_grad = np.array([[160, 160], [160, 160]])
    allclose_nparray(expect_out, out.asnumpy(), 0.001, 0.001)
    allclose_nparray(expect_input_grad, grads[0].asnumpy(), 0.001, 0.001)
342
def pynative_hook_outermost_cell_record_grad():
    """With a custom bprop in debug mode, the hook must record nothing."""
    net = MsSingleOpNetWithBprop()
    net.set_grad()
    net.bprop_debug = True
    net.register_backward_hook(net.ms_record_hook)
    inp = Tensor(np.ones([2, 2]).astype(np.float32))
    out = net(inp)
    grad_net = GradOfAllInputs(net)
    grad_net.set_train()
    grads = grad_net(inp, out)

    # The cell's own bprop replaces the backward, so nothing is recorded.
    assert not net.grad_output_list
    assert not net.grad_input_list

    # Forward output and input gradient from the custom bprop.
    expect_out = np.array([[1, 1], [1, 1]])
    expect_input_grad = np.array([[5, 5], [5, 5]])
    allclose_nparray(expect_out, out.asnumpy(), 0.001, 0.001)
    allclose_nparray(expect_input_grad, grads[0].asnumpy(), 0.001, 0.001)
364
def pynative_hook_bprop_outermost_cell_record_grad():
    """The outer-cell hook still records when only a child has a bprop."""
    net = MsNetHasBpropInChild()
    net.set_grad()
    net.bprop_net.bprop_debug = True
    net.register_backward_hook(net.ms_record_hook)
    inp = Tensor(np.ones([2, 2]).astype(np.float32))
    out = net(inp)
    grad_net = GradOfAllInputs(net)
    grad_net.set_train()
    grads = grad_net(inp, out)

    # The hook must have fired and recorded matching numbers of grads.
    assert net.grad_output_list
    assert len(net.grad_output_list) == len(net.grad_input_list)

    # Forward output and input gradient.
    allclose_nparray(np.array([[5, 5], [5, 5]]), out.asnumpy(), 0.001, 0.001)
    allclose_nparray(np.array([[25, 25], [25, 25]]), grads[0].asnumpy(), 0.001, 0.001)
    # Gradients captured by the hook on the outermost cell.
    allclose_nparray(np.array([[5, 5], [5, 5]]), net.grad_input_list[0].asnumpy(), 0.001, 0.001)
    allclose_nparray(np.array([[25, 25], [25, 25]]), net.grad_output_list[0].asnumpy(), 0.001, 0.001)
391
def pynative_hook_child_cell_record_grad():
    """Hooks on children of a bprop_debug cell must record nothing."""
    net = MsMultiOpNetWithBprop()
    net.set_grad()
    net.bprop_debug = True
    net.relu.register_backward_hook(net.ms_record_hook)
    net.mul.register_backward_hook(net.ms_record_hook)
    inp = Tensor(np.ones([2, 2]).astype(np.float32))
    out = net(inp)
    grad_net = GradOfAllInputs(net)
    grad_net.set_train()
    grad_net(inp, out)

    # The parent's bprop short-circuits the children, so no grads recorded.
    assert not net.grad_output_list
    assert not net.grad_input_list
408
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_diff_hook_ascend():
    """Ascend/PyNative: different hooks (record vs. doubling) on sub-cells."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_diff_hook()
416
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_diff_hook_gpu():
    """GPU/PyNative: different hooks (record vs. doubling) on sub-cells."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_diff_hook()
423
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_not_change_grad_ascend():
    """Ascend/PyNative: recording hook leaves input grad unchanged."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_outermost_cell_not_change_grad()
431
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_not_change_grad_gpu():
    """GPU/PyNative: recording hook leaves input grad unchanged."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_outermost_cell_not_change_grad()
438
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_all_cell_record_grad_ascend():
    """Ascend/PyNative: hooks on all sub-cells record every grad."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_all_cell_record_grad()
446
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_all_cell_record_grad_gpu():
    """GPU/PyNative: hooks on all sub-cells record every grad."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_all_cell_record_grad()
453
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_mul_change_input_grad_ascend():
    """Ascend/PyNative: doubling hook on mul cell doubles input grad."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_mul_change_input_grad()
461
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_mul_change_input_grad_gpu():
    """GPU/PyNative: doubling hook on mul cell doubles input grad."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_mul_change_input_grad()
468
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_mul2_change_input_grad_ascend():
    """Ascend/PyNative: doubling hook on mul2 changes both input grads."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_mul2_change_input_grad()
476
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_mul2_change_input_grad_gpu():
    """GPU/PyNative: doubling hook on mul2 changes both input grads."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_mul2_change_input_grad()
483
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_change_grad_ascend():
    """Ascend/PyNative: doubling hook on outermost cell scales input grad."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_outermost_cell_change_grad()
491
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_change_grad_gpu():
    """GPU/PyNative: doubling hook on outermost cell scales input grad."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_outermost_cell_change_grad()
498
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_record_grad_ascend():
    """Ascend/PyNative: hook records nothing when cell has bprop_debug."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_outermost_cell_record_grad()
506
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_outermost_cell_record_grad_gpu():
    """GPU/PyNative: hook records nothing when cell has bprop_debug."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_outermost_cell_record_grad()
513
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_bprop_outermost_cell_record_grad_ascend():
    """Ascend/PyNative: outer hook records when only a child has bprop."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_bprop_outermost_cell_record_grad()
521
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_bprop_outermost_cell_record_grad_gpu():
    """GPU/PyNative: outer hook records when only a child has bprop."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_bprop_outermost_cell_record_grad()
528
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_hook_child_cell_record_grad_ascend():
    """Ascend/PyNative: child hooks record nothing under parent bprop."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_hook_child_cell_record_grad()
536
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_hook_child_cell_record_grad_gpu():
    """GPU/PyNative: child hooks record nothing under parent bprop."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    pynative_hook_child_cell_record_grad()
543