• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test_fn_bprop """
16import numpy as np
17import pytest
18
19import mindspore as ms
20import mindspore.common.dtype as mstype
21import mindspore.nn as nn
22from mindspore import Parameter
23from mindspore import context
24from mindspore.common.api import jit
25from mindspore.common.tensor import Tensor
26from mindspore.ops import composite as C
27from mindspore.ops import operations as P
28from mindspore.ops.functional import vjp
29from mindspore.ops.function.grad.grad_func import custom_vjp
30
# All tests in this file exercise custom vjp rules in graph mode.
context.set_context(mode=context.GRAPH_MODE)

# Grad operator returning gradients w.r.t. all positional inputs.
grad_all = C.GradOperation(get_all=True)
34
35
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_custom_vjp_mul_add():
    """
    Features: Custom function bprop
    Description: Get the custom vjp of mul_add function.
    Expectation: No exception.
    """

    @custom_vjp
    def mul_add(x, y):
        return 2 * x + y

    # Custom backward rule; it deliberately differs from the analytic
    # gradient of y (2 * y instead of dout) so the assertions can tell
    # the custom rule was actually used.
    def mul_add_bwd(x, y, out, dout):
        return 2 * dout, 2 * y

    mul_add.defbwd(mul_add_bwd)

    a = Tensor(1, dtype=ms.int32)
    b = Tensor(2, dtype=ms.int32)
    cotangent = Tensor(1, dtype=ms.int32)
    _, grad_fn = vjp(mul_add, a, b)
    da, db = grad_fn(cotangent)
    assert da == Tensor(2, dtype=ms.int32)
    assert db == Tensor(4, dtype=ms.int32)
62
63
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_custom_vjp_inline_mul_add():
    """
    Features: Custom function bprop
    Description: Get the custom vjp when mul_add function is inline with other function.
    Expectation: No exception.
    """

    @custom_vjp
    def mul_add(x, y):
        return 2 * x + y

    # Custom gradient rule for mul_add only; the surrounding "+ x" and
    # "+ scale * y" terms below are differentiated automatically.
    def mul_add_bwd(x, y, out, dout):
        return 2 * dout, 2 * y

    mul_add.defbwd(mul_add_bwd)

    @jit
    def inline_mul_add(x, y):
        scale = 2
        return mul_add(x, y) + x + scale * y

    a = Tensor(1, dtype=ms.int32)
    b = Tensor(2, dtype=ms.int32)
    cotangent = Tensor(1, dtype=ms.int32)
    _, grad_fn = vjp(inline_mul_add, a, b)
    da, db = grad_fn(cotangent)
    assert da == Tensor(3, dtype=ms.int32)
    assert db == Tensor(6, dtype=ms.int32)
95
96
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_custom_vjp_with_no_bprop():
    """
    Features: Custom function bprop
    Description: Get the vjp with no bprop.
    Expectation: No exception.
    """

    # No @custom_vjp here: gradients come from plain autodiff,
    # so d/dx(2x + y) = 2 and d/dy(2x + y) = 1.
    def with_no_bprop(x, y):
        return 2 * x + y

    a = Tensor(1, dtype=ms.int32)
    b = Tensor(2, dtype=ms.int32)
    cotangent = Tensor(1, dtype=ms.int32)
    _, grad_fn = vjp(with_no_bprop, a, b)
    da, db = grad_fn(cotangent)
    assert da == Tensor(2, dtype=ms.int32)
    assert db == Tensor(1, dtype=ms.int32)
117
118
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_custom_vjp_bprop_in_fn_2():
    """
    Features: Custom function bprop
    Description: Get the custom vjp when bprop in fn_2.
    Expectation: No exception.
    """

    def inner(x, y):
        # y is intentionally unused, so its autodiff gradient is zero.
        relu = P.ReLU()
        return relu(x)

    @custom_vjp
    def wrapped(x, y):
        # Forward pass also exposes inner's gradients as extra outputs.
        grads = grad_all(inner)(x, y)
        return inner(x, y), grads[0], grads[1]

    def wrapped_bwd(x, y, out, dout):
        grads = grad_all(inner)(x, y)
        return out[1], grads[1]

    wrapped.defbwd(wrapped_bwd)

    @jit
    def entry(x, y):
        return wrapped(x, y)

    sens = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    x = Tensor(np.ones([2, 2]).astype(np.float32))
    y = Tensor(np.ones([2, 2]).astype(np.float32))

    _, grad_fn = vjp(entry, x, y)
    dx, dy = grad_fn(sens, sens, sens)
    assert (dx.asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
    assert (dy.asnumpy() == np.zeros([2, 2]).astype(np.float32)).all()
156
157
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_custom_vjp_bprop_in_fn3():
    """
    Features: Custom function bprop
    Description: Get the custom vjp when bprop in fn_3.
    Expectation: No exception.
    """

    def inner(x, y):
        # y is intentionally unused, so its autodiff gradient is zero.
        relu = P.ReLU()
        return relu(x)

    @custom_vjp
    def mid(x, y):
        grads = grad_all(inner)(x, y)
        return inner(x, y), grads[0], grads[1]

    def mid_bwd(x, y, out, dout):
        grads = grad_all(inner)(x, y)
        return out[1], grads[1]

    mid.defbwd(mid_bwd)

    @custom_vjp
    def outer(x, y):
        return mid(x, y)

    # The outermost custom rule defines outer's gradients directly; the
    # expected values below follow from this formula with all-ones inputs.
    def outer_bwd(x, y, out, dout):
        return x + y + y + out[0], x + x + y + y + dout[0]

    outer.defbwd(outer_bwd)

    sens = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    x = Tensor(np.ones([2, 2]).astype(np.float32))
    y = Tensor(np.ones([2, 2]).astype(np.float32))
    _, grad_fn = vjp(outer, x, y)
    dx, dy = grad_fn(sens, sens, sens)
    assert (dx.asnumpy() == np.full([2, 2], 4, np.float32)).all()
    assert (dy.asnumpy() == np.full([2, 2], 5, np.float32)).all()
199
200
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_custom_vjp_one_input_bprop():
    """
    Features: Custom function bprop
    Description: Get the custom vjp when the function has only one input.
    Expectation: No exception.
    """

    # A single-input bprop must still return a tuple of gradients.
    def bprop_fn(x, out, dout):
        return (5 * x,)

    @custom_vjp
    def fn(x):
        op = P.ReLU()
        return op(x)

    fn.defbwd(bprop_fn)
    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    _, grad_fn = vjp(fn, input1)
    grads = grad_fn(v)
    # Expected array is full (2, 2) to match the gradient's shape; the
    # previous shape-(2,) array only compared equal via broadcasting.
    assert (grads[0].asnumpy() == np.array([[5, 5], [5, 5]]).astype(np.float32)).all()
225
226
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_custom_vjp_inline_bprop_two_input():
    """
    Features: Custom function bprop
    Description: Get the custom vjp when the function has two inputs.
    Expectation: No exception.
    """

    def fn_1(x, y):
        return x * y

    @custom_vjp
    def fn_2(x, y):
        # Forward pass also exposes fn_1's gradients as extra outputs.
        grads = grad_all(fn_1)(x, y)
        return fn_1(x, y), grads[0], grads[1]

    # Custom rule: double fn_1's autodiff gradients.
    def bprop_fn_2(x, y, out, dout):
        grads = grad_all(fn_1)(x, y)
        return grads[0] * 2, grads[1] * 2

    fn_2.defbwd(bprop_fn_2)

    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
    input2 = Tensor(np.ones([2, 2]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    _, grad_fn = vjp(fn_2, input1, input2)
    grads = grad_fn(v, v, v)
    # Expected arrays are full (2, 2) to match the gradient shapes; the
    # previous shape-(2,) arrays only compared equal via broadcasting.
    assert (grads[0].asnumpy() == np.array([[2, 2], [2, 2]]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.array([[2, 2], [2, 2]]).astype(np.float32)).all()
    assert len(grads) == 2
259
260
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_custom_vjp_inline_bprop_multi_input():
    """
    Features: Custom function bprop
    Description: Get the custom vjp of hybrid bprop function.
    Expectation: No exception.
    """

    def plain_mul(x, y):
        return x * y

    @custom_vjp
    def custom_mul_a(x, y):
        op = P.Mul()
        return op(x, y)

    def custom_mul_a_bwd(x, y, out, dout):
        return 5 * x, 8 * y

    custom_mul_a.defbwd(custom_mul_a_bwd)

    def shifted_mul(x, y):
        op = P.Mul()
        return op(1 + x, y)

    @custom_vjp
    def custom_mul_b(x, y):
        op = P.Mul()
        return op(x, y)

    def custom_mul_b_bwd(x, y, out, dout):
        return 5 * x, 8 * y

    custom_mul_b.defbwd(custom_mul_b_bwd)

    # Mix of custom-bprop and autodiff sub-functions in one expression.
    # With all-ones inputs: dx = 5 + 1 + 1 + 5 = 12, dy = 8 + 1 + 2 + 8 = 19.
    def inline_multi_two_input(x, y):
        return (
            custom_mul_a(x, y) + plain_mul(x, y) + shifted_mul(x, y) + custom_mul_b(x, y)
        )

    lhs = Tensor(np.ones([2, 2]).astype(np.float32))
    rhs = Tensor(np.ones([2, 2]).astype(np.float32))
    sens = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    _, grad_fn = vjp(inline_multi_two_input, lhs, rhs)
    grads = grad_fn(sens)
    assert (
        grads[0].asnumpy() == np.array([[12, 12], [12, 12]]).astype(np.float32)
    ).all()
    assert (
        grads[1].asnumpy() == np.array([[19, 19], [19, 19]]).astype(np.float32)
    ).all()
    assert len(grads) == 2
317
318
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_custom_vjp_fn_with_net():
    """
    Features: Custom function bprop
    Description: Get the custom vjp when the function contains Cell.
    Expectation: No exception.
    """

    class MatMulNet(nn.Cell):
        # Cell wrapped by the custom-vjp function below.
        def __init__(self):
            super(MatMulNet, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name="z")

        def construct(self, x, y):
            x = x * self.z
            return self.matmul(x, y)

    # Custom rule: gradients are the doubled inputs, independent of the
    # matmul forward — the expected arrays below are exactly 2*x and 2*y.
    def fn_bprop(x, y, out, dout):
        return x + x, y + y

    @custom_vjp
    def fn(x, y):
        net = MatMulNet()
        return net(x, y)

    fn.defbwd(fn_bprop)

    x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    y = Tensor(
        [[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32
    )
    out = grad_all(fn)(x, y)
    expect_dx = np.array([[1.0, 1.2, 0.8], [2.4, 2.6, 2.2]]).astype(np.float32)
    expect_dy = np.array([[0.02, 0.6, 2.2], [0.2, 0.4, 2.6], [4.2, 2.4, 6.6]]).astype(
        np.float32
    )
    assert np.allclose(out[0].asnumpy(), expect_dx)
    assert np.allclose(out[1].asnumpy(), expect_dy)
367
368
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_custom_vjp_forward_net_call_fn():
    """
    Feature: Custom function bprop
    Description: Get the custom vjp when the forward net call the function.
    Expectation: No exception.
    """

    class MatMulCell(nn.Cell):
        # Cell wrapped by the custom-vjp function below.
        def __init__(self):
            super(MatMulCell, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name="z")

        def construct(self, x, y):
            x = x * self.z
            return self.matmul(x, y)

    @custom_vjp
    def fn(x, y):
        net = MatMulCell()
        return net(x, y)

    # Custom rule: gradients are the doubled inputs, independent of the
    # matmul forward — the expected arrays below are exactly 2*x and 2*y.
    def fn_bprop(x, y, out, dout):
        return x + x, y + y

    fn.defbwd(fn_bprop)

    class Caller(nn.Cell):
        # The forward cell invokes the custom-vjp function from construct.
        def construct(self, x, y):
            return fn(x, y)

    x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    y = Tensor(
        [[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32
    )
    out = grad_all(Caller())(x, y)
    expect_dx = np.array([[1.0, 1.2, 0.8], [2.4, 2.6, 2.2]]).astype(np.float32)
    expect_dy = np.array([[0.02, 0.6, 2.2], [0.2, 0.4, 2.6], [4.2, 2.4, 6.6]]).astype(
        np.float32
    )
    assert np.allclose(out[0].asnumpy(), expect_dx)
    assert np.allclose(out[1].asnumpy(), expect_dy)
421